diff --git a/cmd/import-adjustment-stock-prices/main.go b/cmd/import-adjustment-stock-prices/main.go new file mode 100644 index 00000000..e4092e79 --- /dev/null +++ b/cmd/import-adjustment-stock-prices/main.go @@ -0,0 +1,587 @@ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "sort" + "strconv" + "strings" + + "github.com/xuri/excelize/v2" + "gitlab.com/mbugroup/lti-api.git/internal/config" + "gitlab.com/mbugroup/lti-api.git/internal/database" + "gorm.io/gorm" +) + +type importOptions struct { + FilePath string + Sheet string + Apply bool +} + +type headerIndexes struct { + AdjustmentID int + Weight int +} + +type adjustmentPriceImportRow struct { + RowNumber int + AdjustmentID uint + Weight float64 +} + +type validationIssue struct { + Row int + Field string + Message string +} + +func (i validationIssue) Error() string { + if i.Row > 0 { + return fmt.Sprintf("row=%d field=%s message=%s", i.Row, i.Field, i.Message) + } + return fmt.Sprintf("field=%s message=%s", i.Field, i.Message) +} + +type adjustmentResolver interface { + ResolveExistingAdjustmentIDs(ctx context.Context, adjustmentIDs []uint) (map[uint]struct{}, error) +} + +type dbAdjustmentResolver struct { + db *gorm.DB +} + +type adjustmentPriceStore interface { + UpdatePrice(ctx context.Context, adjustmentID uint, price float64) (bool, error) +} + +type txRunner interface { + InTx(ctx context.Context, fn func(store adjustmentPriceStore) error) error +} + +type dbTxRunner struct { + db *gorm.DB +} + +type dbAdjustmentPriceStore struct { + db *gorm.DB +} + +type applyRowResult struct { + RowNumber int + AdjustmentID uint + Price float64 + Changed bool +} + +func main() { + var opts importOptions + flag.StringVar(&opts.FilePath, "file", "", "Path to .xlsx file (required)") + flag.StringVar(&opts.Sheet, "sheet", "", "Sheet name (optional, default: first sheet)") + flag.BoolVar(&opts.Apply, "apply", false, "Apply changes. If false, run as dry-run") + flag.Parse() + + opts.FilePath = strings.TrimSpace(opts.FilePath) + opts.Sheet = strings.TrimSpace(opts.Sheet) + + if opts.FilePath == "" { + log.Fatal("--file is required") + } + + sheetName, rows, parseIssues, err := parseAdjustmentPriceFile(opts.FilePath, opts.Sheet) + if err != nil { + log.Fatalf("failed reading excel: %v", err) + } + + ctx := context.Background() + db := database.Connect(config.DBHost, config.DBName) + resolver := dbAdjustmentResolver{db: db} + + existingAdjustmentIDs, err := resolver.ResolveExistingAdjustmentIDs(ctx, collectAdjustmentIDs(rows)) + if err != nil { + log.Fatalf("failed checking adjustment_id against adjustment_stocks: %v", err) + } + + processableRows, skippedRows := splitRowsByExistingIDs(rows, existingAdjustmentIDs) + issues := append([]validationIssue{}, parseIssues...) + sortValidationIssues(issues) + + fmt.Printf("Mode: %s\n", modeLabel(opts.Apply)) + fmt.Printf("File: %s\n", opts.FilePath) + fmt.Printf("Sheet: %s\n", sheetName) + fmt.Printf("Rows parsed: %d\n", len(rows)) + fmt.Printf("Rows invalid: %d\n", len(issues)) + fmt.Printf("Rows processable: %d\n", len(processableRows)) + fmt.Printf("Rows skipped_missing: %d\n", len(skippedRows)) + fmt.Println() + + if len(processableRows) > 0 { + printPlanRows(processableRows) + } + if len(skippedRows) > 0 { + printSkippedRows(skippedRows) + } + if len(processableRows) > 0 || len(skippedRows) > 0 { + fmt.Println() + } + + if len(issues) > 0 { + fmt.Println("Validation errors:") + for _, issue := range issues { + fmt.Printf("ERROR %s\n", issue.Error()) + } + fmt.Println() + fmt.Printf( + "Summary: planned=%d processable=%d skipped_missing=%d applied=0 failed=%d\n", + len(rows), + len(processableRows), + len(skippedRows), + len(issues), + ) + os.Exit(1) + } + + if !opts.Apply { + fmt.Printf( + "Summary: planned=%d processable=%d skipped_missing=%d applied=0 failed=0\n", + len(rows), + len(processableRows), + len(skippedRows), + ) + return + } + + results, err := applyIfRequested(ctx, true, dbTxRunner{db: db}, processableRows) + if err != nil { + log.Fatalf("apply failed: %v", err) + } + + for _, result := range results { + fmt.Printf( + "DONE row=%d adjustment_id=%d price=%.3f status=%s\n", + result.RowNumber, + result.AdjustmentID, + result.Price, + applyStatus(result.Changed), + ) + } + + appliedCount := countChangedRows(results) + if len(results) > 0 { + fmt.Println() + } + fmt.Printf( + "Summary: planned=%d processable=%d skipped_missing=%d applied=%d failed=0\n", + len(rows), + len(processableRows), + len(skippedRows), + appliedCount, + ) +} + +func parseAdjustmentPriceFile( + filePath string, + requestedSheet string, +) (string, []adjustmentPriceImportRow, []validationIssue, error) { + workbook, err := excelize.OpenFile(filePath) + if err != nil { + return "", nil, nil, err + } + defer func() { + _ = workbook.Close() + }() + + sheetName, err := resolveSheetName(workbook, requestedSheet) + if err != nil { + return "", nil, nil, err + } + + allRows, err := workbook.GetRows(sheetName, excelize.Options{RawCellValue: true}) + if err != nil { + return "", nil, nil, err + } + if len(allRows) == 0 { + return sheetName, nil, []validationIssue{{Field: "header", Message: "sheet is empty"}}, nil + } + + indexes, headerIssues := parseHeaderIndexes(allRows[0]) + if len(headerIssues) > 0 { + return sheetName, nil, headerIssues, nil + } + + rowsByAdjustmentID := make(map[uint]adjustmentPriceImportRow) + issues := make([]validationIssue, 0) + + for idx := 1; idx < len(allRows); idx++ { + rowNumber := idx + 1 + rawRow := allRows[idx] + + if isRowEmpty(rawRow) { + continue + } + + parsed, rowIssues := parseDataRow(rawRow, rowNumber, indexes) + if len(rowIssues) > 0 { + issues = append(issues, rowIssues...) + continue + } + + rowsByAdjustmentID[parsed.AdjustmentID] = *parsed + } + + rows := make([]adjustmentPriceImportRow, 0, len(rowsByAdjustmentID)) + for _, row := range rowsByAdjustmentID { + rows = append(rows, row) + } + sort.Slice(rows, func(i, j int) bool { + return rows[i].RowNumber < rows[j].RowNumber + }) + + if len(rows) == 0 && len(issues) == 0 { + issues = append(issues, validationIssue{Field: "rows", Message: "no data rows found"}) + } + + return sheetName, rows, issues, nil +} + +func resolveSheetName(workbook *excelize.File, requestedSheet string) (string, error) { + if workbook == nil { + return "", fmt.Errorf("workbook is nil") + } + + sheets := workbook.GetSheetList() + if len(sheets) == 0 { + return "", fmt.Errorf("workbook has no sheets") + } + + if requestedSheet == "" { + return sheets[0], nil + } + + for _, sheet := range sheets { + if strings.EqualFold(strings.TrimSpace(sheet), strings.TrimSpace(requestedSheet)) { + return sheet, nil + } + } + + return "", fmt.Errorf("sheet %q not found", requestedSheet) +} + +func parseHeaderIndexes(headerRow []string) (headerIndexes, []validationIssue) { + indexes := headerIndexes{AdjustmentID: -1, Weight: -1} + issues := make([]validationIssue, 0) + + for idx, raw := range headerRow { + header := normalizeHeader(raw) + if header == "" { + continue + } + + switch header { + case "adjustment_id": + if indexes.AdjustmentID >= 0 { + issues = append(issues, validationIssue{Field: "header", Message: "duplicate header adjustment_id"}) + } + indexes.AdjustmentID = idx + case "weight": + if indexes.Weight >= 0 { + issues = append(issues, validationIssue{Field: "header", Message: "duplicate header weight"}) + } + indexes.Weight = idx + } + } + + if indexes.AdjustmentID < 0 { + issues = append(issues, validationIssue{Field: "adjustment_id", Message: "required header is missing"}) + } + if indexes.Weight < 0 { + issues = append(issues, validationIssue{Field: "weight", Message: "required header is missing"}) + } + + return indexes, issues +} + +func parseDataRow( + rawRow []string, + rowNumber int, + indexes headerIndexes, +) (*adjustmentPriceImportRow, []validationIssue) { + issues := make([]validationIssue, 0) + + adjustmentIDRaw := strings.TrimSpace(cellValue(rawRow, indexes.AdjustmentID)) + adjustmentID, err := parsePositiveUint(adjustmentIDRaw) + if err != nil { + issues = append(issues, validationIssue{Row: rowNumber, Field: "adjustment_id", Message: err.Error()}) + } + + weightRaw := strings.TrimSpace(cellValue(rawRow, indexes.Weight)) + weight, err := parseNonNegativeFloat(weightRaw) + if err != nil { + issues = append(issues, validationIssue{Row: rowNumber, Field: "weight", Message: err.Error()}) + } + + if len(issues) > 0 { + return nil, issues + } + + return &adjustmentPriceImportRow{ + RowNumber: rowNumber, + AdjustmentID: adjustmentID, + Weight: weight, + }, nil +} + +func parsePositiveUint(raw string) (uint, error) { + if raw == "" { + return 0, fmt.Errorf("is required") + } + + uintValue, err := strconv.ParseUint(raw, 10, 64) + if err == nil { + if uintValue == 0 { + return 0, fmt.Errorf("must be greater than 0") + } + return uint(uintValue), nil + } + + floatValue, floatErr := strconv.ParseFloat(raw, 64) + if floatErr != nil { + return 0, fmt.Errorf("must be a positive integer") + } + if floatValue <= 0 { + return 0, fmt.Errorf("must be greater than 0") + } + if floatValue != float64(uint(floatValue)) { + return 0, fmt.Errorf("must be a positive integer") + } + + return uint(floatValue), nil +} + +func parseNonNegativeFloat(raw string) (float64, error) { + if raw == "" { + return 0, fmt.Errorf("is required") + } + + value, err := strconv.ParseFloat(raw, 64) + if err != nil { + return 0, fmt.Errorf("must be numeric") + } + if value < 0 { + return 0, fmt.Errorf("must be greater than or equal to 0") + } + + return value, nil +} + +func isRowEmpty(row []string) bool { + for _, cell := range row { + if strings.TrimSpace(cell) != "" { + return false + } + } + return true +} + +func normalizeHeader(raw string) string { + return strings.ToLower(strings.TrimSpace(raw)) +} + +func cellValue(row []string, index int) string { + if index < 0 || index >= len(row) { + return "" + } + return row[index] +} + +func collectAdjustmentIDs(rows []adjustmentPriceImportRow) []uint { + ids := make([]uint, 0, len(rows)) + seen := make(map[uint]struct{}, len(rows)) + for _, row := range rows { + if row.AdjustmentID == 0 { + continue + } + if _, exists := seen[row.AdjustmentID]; exists { + continue + } + seen[row.AdjustmentID] = struct{}{} + ids = append(ids, row.AdjustmentID) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + return ids +} + +func (r dbAdjustmentResolver) ResolveExistingAdjustmentIDs( + ctx context.Context, + adjustmentIDs []uint, +) (map[uint]struct{}, error) { + result := make(map[uint]struct{}) + if len(adjustmentIDs) == 0 { + return result, nil + } + + type adjustmentIDRow struct { + ID uint `gorm:"column:id"` + } + + rows := make([]adjustmentIDRow, 0, len(adjustmentIDs)) + if err := r.db.WithContext(ctx). + Table("adjustment_stocks"). + Select("id"). + Where("id IN ?", adjustmentIDs). + Scan(&rows).Error; err != nil { + return nil, err + } + + for _, row := range rows { + result[row.ID] = struct{}{} + } + + return result, nil +} + +func splitRowsByExistingIDs( + rows []adjustmentPriceImportRow, + existing map[uint]struct{}, +) ([]adjustmentPriceImportRow, []adjustmentPriceImportRow) { + processable := make([]adjustmentPriceImportRow, 0, len(rows)) + skipped := make([]adjustmentPriceImportRow, 0) + + for _, row := range rows { + if _, exists := existing[row.AdjustmentID]; exists { + processable = append(processable, row) + continue + } + skipped = append(skipped, row) + } + + return processable, skipped +} + +func printPlanRows(rows []adjustmentPriceImportRow) { + for _, row := range rows { + fmt.Printf( + "PLAN row=%d adjustment_id=%d price=%.3f\n", + row.RowNumber, + row.AdjustmentID, + row.Weight, + ) + } +} + +func printSkippedRows(rows []adjustmentPriceImportRow) { + for _, row := range rows { + fmt.Printf( + "SKIP row=%d adjustment_id=%d reason=adjustment_id not found\n", + row.RowNumber, + row.AdjustmentID, + ) + } +} + +func sortValidationIssues(issues []validationIssue) { + sort.Slice(issues, func(i, j int) bool { + if issues[i].Row == issues[j].Row { + if issues[i].Field == issues[j].Field { + return issues[i].Message < issues[j].Message + } + return issues[i].Field < issues[j].Field + } + return issues[i].Row < issues[j].Row + }) +} + +func applyIfRequested( + ctx context.Context, + apply bool, + runner txRunner, + rows []adjustmentPriceImportRow, +) ([]applyRowResult, error) { + if !apply || len(rows) == 0 { + return nil, nil + } + return applyImportRows(ctx, runner, rows) +} + +func applyImportRows( + ctx context.Context, + runner txRunner, + rows []adjustmentPriceImportRow, +) ([]applyRowResult, error) { + results := make([]applyRowResult, 0, len(rows)) + + err := runner.InTx(ctx, func(store adjustmentPriceStore) error { + for _, row := range rows { + changed, err := store.UpdatePrice(ctx, row.AdjustmentID, row.Weight) + if err != nil { + return fmt.Errorf("row %d adjustment_id=%d update failed: %w", row.RowNumber, row.AdjustmentID, err) + } + + results = append(results, applyRowResult{ + RowNumber: row.RowNumber, + AdjustmentID: row.AdjustmentID, + Price: row.Weight, + Changed: changed, + }) + } + return nil + }) + if err != nil { + return nil, err + } + + return results, nil +} + +func (r dbTxRunner) InTx(ctx context.Context, fn func(store adjustmentPriceStore) error) error { + return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + return fn(dbAdjustmentPriceStore{db: tx}) + }) +} + +func (s dbAdjustmentPriceStore) UpdatePrice( + ctx context.Context, + adjustmentID uint, + price float64, +) (bool, error) { + result := s.db.WithContext(ctx).Exec(` + UPDATE adjustment_stocks + SET price = ?, + updated_at = NOW() + WHERE id = ? + AND price IS DISTINCT FROM ? + `, price, adjustmentID, price) + if result.Error != nil { + return false, result.Error + } + return result.RowsAffected > 0, nil +} + +func modeLabel(apply bool) string { + if apply { + return "APPLY" + } + return "DRY-RUN" +} + +func applyStatus(changed bool) string { + if changed { + return "UPDATED" + } + return "UNCHANGED" +} + +func countChangedRows(results []applyRowResult) int { + count := 0 + for _, result := range results { + if result.Changed { + count++ + } + } + return count +} diff --git a/cmd/import-adjustment-stock-prices/main_test.go b/cmd/import-adjustment-stock-prices/main_test.go new file mode 100644 index 00000000..121ddc81 --- /dev/null +++ b/cmd/import-adjustment-stock-prices/main_test.go @@ -0,0 +1,362 @@ +package main + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "strings" + "testing" + + "github.com/xuri/excelize/v2" +) + +func TestParseAdjustmentPriceFile_ValidSingleRow(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{"adjustment_id", "weight"}, + [][]string{{"101", "12.345"}}, + ) + + sheet, rows, issues, err := parseAdjustmentPriceFile(filePath, "") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if sheet != "adjustment_prices" { + t.Fatalf("expected selected sheet adjustment_prices, got %q", sheet) + } + if len(issues) != 0 { + t.Fatalf("expected no issues, got %+v", issues) + } + if len(rows) != 1 { + t.Fatalf("expected 1 row, got %d", len(rows)) + } + if rows[0].AdjustmentID != 101 { + t.Fatalf("expected adjustment_id 101, got %d", rows[0].AdjustmentID) + } + if rows[0].Weight != 12.345 { + t.Fatalf("expected weight 12.345, got %v", rows[0].Weight) + } +} + +func TestParseAdjustmentPriceFile_ValidMultiRow(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{" Adjustment_ID ", "WEIGHT"}, + [][]string{{"101", "10"}, {"102", "11.5"}}, + ) + + _, rows, issues, err := parseAdjustmentPriceFile(filePath, "adjustment_prices") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(issues) != 0 { + t.Fatalf("expected no issues, got %+v", issues) + } + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } +} + +func TestParseAdjustmentPriceFile_MissingRequiredHeader(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{"adjustment_id", "price"}, + [][]string{{"101", "12"}}, + ) + + _, rows, issues, err := parseAdjustmentPriceFile(filePath, "") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(rows) != 0 { + t.Fatalf("expected 0 parsed rows when header invalid, got %d", len(rows)) + } + if !hasIssue(issues, 0, "weight", "required header is missing") { + t.Fatalf("expected missing weight header issue, got %+v", issues) + } +} + +func TestParseAdjustmentPriceFile_InvalidAdjustmentID(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{"adjustment_id", "weight"}, + [][]string{{"abc", "10"}, {"0", "12"}}, + ) + + _, rows, issues, err := parseAdjustmentPriceFile(filePath, "") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(rows) != 0 { + t.Fatalf("expected no valid rows, got %d", len(rows)) + } + if !hasIssue(issues, 2, "adjustment_id", "must be a positive integer") { + t.Fatalf("expected non numeric adjustment_id issue, got %+v", issues) + } + if !hasIssue(issues, 3, "adjustment_id", "must be greater than 0") { + t.Fatalf("expected adjustment_id >0 issue, got %+v", issues) + } +} + +func TestParseAdjustmentPriceFile_InvalidWeight(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{"adjustment_id", "weight"}, + [][]string{{"101", "abc"}, {"102", "-1"}}, + ) + + _, rows, issues, err := parseAdjustmentPriceFile(filePath, "") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(rows) != 0 { + t.Fatalf("expected no valid rows, got %d", len(rows)) + } + if !hasIssue(issues, 2, "weight", "must be numeric") { + t.Fatalf("expected weight numeric issue, got %+v", issues) + } + if !hasIssue(issues, 3, "weight", "must be greater than or equal to 0") { + t.Fatalf("expected weight >=0 issue, got %+v", issues) + } +} + +func TestParseAdjustmentPriceFile_DuplicateAdjustmentID_LastRowWins(t *testing.T) { + filePath := createWorkbook( + t, + "adjustment_prices", + []string{"adjustment_id", "weight"}, + [][]string{{"101", "10"}, {"102", "20"}, {"101", "30"}}, + ) + + _, rows, issues, err := parseAdjustmentPriceFile(filePath, "") + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if len(issues) != 0 { + t.Fatalf("expected no issues, got %+v", issues) + } + if len(rows) != 2 { + t.Fatalf("expected 2 deduped rows, got %d", len(rows)) + } + + row101, ok := findRowByAdjustmentID(rows, 101) + if !ok { + t.Fatalf("expected adjustment_id 101 to exist, got %+v", rows) + } + if row101.Weight != 30 { + t.Fatalf("expected duplicate adjustment_id to keep last weight 30, got %v", row101.Weight) + } + if row101.RowNumber != 4 { + t.Fatalf("expected duplicate adjustment_id to keep last row number 4, got %d", row101.RowNumber) + } +} + +func TestSplitRowsByExistingIDs_SkipMissing(t *testing.T) { + rows := []adjustmentPriceImportRow{ + {RowNumber: 2, AdjustmentID: 101, Weight: 10}, + {RowNumber: 3, AdjustmentID: 102, Weight: 11}, + {RowNumber: 4, AdjustmentID: 103, Weight: 12}, + } + existing := map[uint]struct{}{101: {}, 103: {}} + + processable, skipped := splitRowsByExistingIDs(rows, existing) + if len(processable) != 2 { + t.Fatalf("expected 2 processable rows, got %d", len(processable)) + } + if len(skipped) != 1 { + t.Fatalf("expected 1 skipped row, got %d", len(skipped)) + } + if skipped[0].AdjustmentID != 102 { + t.Fatalf("expected adjustment_id 102 skipped, got %+v", skipped) + } +} + +func TestApplyIfRequested_DryRunDoesNotWrite(t *testing.T) { + runner := &fakeTransactionRunner{} + rows := []adjustmentPriceImportRow{{RowNumber: 2, AdjustmentID: 101, Weight: 10}} + + results, err := applyIfRequested(context.Background(), false, runner, rows) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if results != nil { + t.Fatalf("expected nil results on dry-run, got %+v", results) + } + if runner.txCalls != 0 { + t.Fatalf("expected no transaction call during dry-run, got %d", runner.txCalls) + } +} + +func TestApplyImportRows_Success(t *testing.T) { + runner := &fakeTransactionRunner{ + changedByID: map[uint]bool{101: true, 102: false}, + } + rows := []adjustmentPriceImportRow{ + {RowNumber: 2, AdjustmentID: 101, Weight: 10}, + {RowNumber: 3, AdjustmentID: 102, Weight: 11}, + } + + results, err := applyImportRows(context.Background(), runner, rows) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if runner.txCalls != 1 { + t.Fatalf("expected 1 transaction call, got %d", runner.txCalls) + } + if len(runner.committedCalls) != 2 { + t.Fatalf("expected 2 committed updates, got %d", len(runner.committedCalls)) + } + if len(results) != 2 { + t.Fatalf("expected 2 row results, got %d", len(results)) + } + if !results[0].Changed || results[1].Changed { + t.Fatalf("unexpected changed flags: %+v", results) + } +} + +func TestApplyImportRows_RollbackOnError(t *testing.T) { + runner := &fakeTransactionRunner{ + errByID: map[uint]error{102: errors.New("boom")}, + } + rows := []adjustmentPriceImportRow{ + {RowNumber: 2, AdjustmentID: 101, Weight: 10}, + {RowNumber: 3, AdjustmentID: 102, Weight: 11}, + } + + _, err := applyImportRows(context.Background(), runner, rows) + if err == nil { + t.Fatal("expected error due to update failure") + } + if !strings.Contains(err.Error(), "row 3 adjustment_id=102 update failed") { + t.Fatalf("unexpected error message: %v", err) + } + if runner.txCalls != 1 { + t.Fatalf("expected 1 transaction call, got %d", runner.txCalls) + } + if len(runner.committedCalls) != 0 { + t.Fatalf("expected no committed updates on rollback, got %d", len(runner.committedCalls)) + } +} + +func createWorkbook(t *testing.T, sheetName string, headers []string, rows [][]string) string { + t.Helper() + + f := excelize.NewFile() + defaultSheet := f.GetSheetName(f.GetActiveSheetIndex()) + if sheetName == "" { + sheetName = defaultSheet + } else if sheetName != defaultSheet { + f.SetSheetName(defaultSheet, sheetName) + } + + for idx, header := range headers { + cell, err := excelize.CoordinatesToCellName(idx+1, 1) + if err != nil { + t.Fatalf("failed resolving header cell: %v", err) + } + if err := f.SetCellValue(sheetName, cell, header); err != nil { + t.Fatalf("failed setting header cell: %v", err) + } + } + + for rowIdx, row := range rows { + for colIdx, value := range row { + cell, err := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2) + if err != nil { + t.Fatalf("failed resolving data cell: %v", err) + } + if err := f.SetCellValue(sheetName, cell, value); err != nil { + t.Fatalf("failed setting data cell: %v", err) + } + } + } + + path := filepath.Join(t.TempDir(), "adjustment_prices.xlsx") + if err := f.SaveAs(path); err != nil { + t.Fatalf("failed saving workbook: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("failed closing workbook: %v", err) + } + + return path +} + +func hasIssue(issues []validationIssue, row int, field, messageContains string) bool { + for _, issue := range issues { + if issue.Row != row { + continue + } + if issue.Field != field { + continue + } + if strings.Contains(issue.Message, messageContains) { + return true + } + } + return false +} + +func findRowByAdjustmentID(rows []adjustmentPriceImportRow, adjustmentID uint) (adjustmentPriceImportRow, bool) { + for _, row := range rows { + if row.AdjustmentID == adjustmentID { + return row, true + } + } + return adjustmentPriceImportRow{}, false +} + +type updateCall struct { + adjustmentID uint + price float64 +} + +type fakeAdjustmentPriceStore struct { + changedByID map[uint]bool + errByID map[uint]error + calls []updateCall +} + +func (s *fakeAdjustmentPriceStore) UpdatePrice(_ context.Context, adjustmentID uint, price float64) (bool, error) { + s.calls = append(s.calls, updateCall{adjustmentID: adjustmentID, price: price}) + if err, exists := s.errByID[adjustmentID]; exists { + return false, fmt.Errorf("forced update failure for adjustment_id=%d: %w", adjustmentID, err) + } + if changed, exists := s.changedByID[adjustmentID]; exists { + return changed, nil + } + return true, nil +} + +type fakeTransactionRunner struct { + txCalls int + changedByID map[uint]bool + errByID map[uint]error + committedCalls []updateCall +} + +func (r *fakeTransactionRunner) InTx(ctx context.Context, fn func(store adjustmentPriceStore) error) error { + r.txCalls++ + + txStore := &fakeAdjustmentPriceStore{ + changedByID: r.changedByID, + errByID: r.errByID, + calls: make([]updateCall, 0), + } + + if err := fn(txStore); err != nil { + return err + } + + r.committedCalls = append(r.committedCalls, txStore.calls...) + return nil +} + +var _ txRunner = (*fakeTransactionRunner)(nil) +var _ adjustmentPriceStore = (*fakeAdjustmentPriceStore)(nil) diff --git a/cmd/run-sql-file/main.go b/cmd/run-sql-file/main.go new file mode 100644 index 00000000..6c8cb094 --- /dev/null +++ b/cmd/run-sql-file/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "flag" + "fmt" + "log" + "os" + "strings" + + "gitlab.com/mbugroup/lti-api.git/internal/config" + "gitlab.com/mbugroup/lti-api.git/internal/database" + "gorm.io/gorm" +) + +type options struct { + FilePath string + Apply bool +} + +func main() { + var opts options + flag.StringVar(&opts.FilePath, "file", "", "Path to .sql file (required)") + flag.BoolVar(&opts.Apply, "apply", false, "Apply SQL to database. If false, run as dry-run") + flag.Parse() + + opts.FilePath = strings.TrimSpace(opts.FilePath) + if opts.FilePath == "" { + log.Fatal("--file is required") + } + + sqlContent, err := readSQLFile(opts.FilePath) + if err != nil { + log.Fatalf("failed reading sql file: %v", err) + } + + mode := "dry-run" + if opts.Apply { + mode = "apply" + } + fmt.Printf("Mode: %s\n", mode) + fmt.Printf("File: %s\n", opts.FilePath) + fmt.Printf("SQL bytes: %d\n", len(sqlContent)) + + if !opts.Apply { + fmt.Println("Dry-run only. Add --apply to execute the SQL file.") + return + } + + db := database.Connect(config.DBHost, config.DBName) + if err := executeSQL(db, sqlContent); err != nil { + log.Fatalf("failed executing sql file: %v", err) + } + + fmt.Println("DONE: SQL executed successfully") +} + +func readSQLFile(path string) (string, error) { + raw, err := os.ReadFile(path) + if err != nil { + return "", err + } + + sql := strings.TrimSpace(strings.TrimPrefix(string(raw), "\ufeff")) + if sql == "" { + return "", fmt.Errorf("sql file is empty") + } + + return sql, nil +} + +func executeSQL(db *gorm.DB, sql string) error { + return db.Transaction(func(tx *gorm.DB) error { + return tx.Exec(sql).Error + }) +} diff --git a/docs/templates/adjustment_stock_prices.xlsx b/docs/templates/adjustment_stock_prices.xlsx new file mode 100644 index 00000000..3e1ce3de Binary files /dev/null and b/docs/templates/adjustment_stock_prices.xlsx differ diff --git a/docs/templates/~$adjustment_stock_prices.xlsx b/docs/templates/~$adjustment_stock_prices.xlsx new file mode 100644 index 00000000..5a932052 Binary files /dev/null and b/docs/templates/~$adjustment_stock_prices.xlsx differ diff --git a/internal/modules/dashboards/module.go b/internal/modules/dashboards/module.go index d7d0d477..622222c5 100644 --- a/internal/modules/dashboards/module.go +++ b/internal/modules/dashboards/module.go @@ -18,11 +18,11 @@ type DashboardModule struct{} func (DashboardModule) RegisterRoutes(router fiber.Router, db *gorm.DB, validate *validator.Validate) { dashboardRepo := rDashboard.NewDashboardRepository(db) - hppCostRepo := commonRepo.NewHppCostRepository(db) + hppV2CostRepo := commonRepo.NewHppV2CostRepository(db) userRepo := rUser.NewUserRepository(db) - hppSvc := commonService.NewHppService(hppCostRepo) - dashboardService := sDashboard.NewDashboardService(dashboardRepo, validate, hppSvc) + hppV2Svc := commonService.NewHppV2Service(hppV2CostRepo) + dashboardService := sDashboard.NewDashboardService(dashboardRepo, validate, hppV2Svc) userService := sUser.NewUserService(userRepo, validate) DashboardRoutes(router, userService, dashboardService) diff --git a/internal/modules/dashboards/services/dashboard.service.go b/internal/modules/dashboards/services/dashboard.service.go index 275b53f3..3811917d 100644 --- a/internal/modules/dashboards/services/dashboard.service.go +++ b/internal/modules/dashboards/services/dashboard.service.go @@ -30,10 +30,10 @@ type dashboardService struct { Log *logrus.Logger Validate *validator.Validate Repository repository.DashboardRepository - HppSvc commonService.HppService + HppSvc commonService.HppV2Service } -func NewDashboardService(repo repository.DashboardRepository, validate *validator.Validate, hppSvc commonService.HppService) DashboardService { +func NewDashboardService(repo repository.DashboardRepository, validate *validator.Validate, hppSvc commonService.HppV2Service) DashboardService { return &dashboardService{ Log: utils.Log, Validate: validate, diff --git a/internal/modules/repports/services/repport.service.go b/internal/modules/repports/services/repport.service.go index 3ffa1e09..572a2317 100644 --- a/internal/modules/repports/services/repport.service.go +++ b/internal/modules/repports/services/repport.service.go @@ -773,7 +773,7 @@ func (s *repportService) GetMarketing(c *fiber.Ctx, params *validation.Marketing return nil, 0, err } - hppByDelivery := buildMarketingHppByDelivery(c.Context(), s.HppSvc, attributionRows) + hppByDelivery := buildMarketingHppByDelivery(c.Context(), s.HppV2Svc, attributionRows) categoryByDelivery := buildMarketingCategoryByDelivery(deliveryProducts, attributionRows) items := dto.ToMarketingReportItems(deliveryProducts, hppByDelivery, categoryByDelivery, agingMap) @@ -782,7 +782,7 @@ func (s *repportService) GetMarketing(c *fiber.Ctx, params *validation.Marketing func buildMarketingHppByDelivery( ctx context.Context, - hppSvc approvalService.HppService, + hppSvc approvalService.HppV2Service, attributionRows []commonRepo.MarketingDeliveryAttributionRow, ) map[uint]float64 { if len(attributionRows) == 0 {