diff --git a/.golangci.bck.yml b/.golangci.bck.yml new file mode 100644 index 0000000..d9266b8 --- /dev/null +++ b/.golangci.bck.yml @@ -0,0 +1,110 @@ +run: + timeout: 5m + tests: true + skip-dirs: + - vendor + - .github + +linters: + enable: + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - unused + - gofmt + - goimports + - misspell + - gocritic + - revive + - stylecheck + disable: + - typecheck # Can cause issues with generics in some cases + +linters-settings: + errcheck: + check-type-assertions: false + check-blank: false + + govet: + check-shadowing: false + + gofmt: + simplify: true + + goimports: + local-prefixes: github.com/bitechdev/ResolveSpec + + gocritic: + enabled-checks: + - appendAssign + - assignOp + - boolExprSimplify + - builtinShadow + - captLocal + - caseOrder + - defaultCaseOrder + - dupArg + - dupBranchBody + - dupCase + - dupSubExpr + - elseif + - emptyFallthrough + - equalFold + - flagName + - ifElseChain + - indexAlloc + - initClause + - methodExprCall + - nilValReturn + - rangeExprCopy + - rangeValCopy + - regexpMust + - singleCaseSwitch + - sloppyLen + - stringXbytes + - switchTrue + - typeAssertChain + - typeSwitchVar + - underef + - unlabelStmt + - unnamedResult + - unnecessaryBlock + - weakCond + - yodaStyleExpr + + revive: + rules: + - name: exported + disabled: true + - name: package-comments + disabled: true + +issues: + exclude-use-default: false + max-issues-per-linter: 0 + max-same-issues: 0 + + # Exclude some linters from running on tests files + exclude-rules: + - path: _test\.go + linters: + - errcheck + - dupl + - gosec + - gocritic + + # Ignore "error return value not checked" for defer statements + - linters: + - errcheck + text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked" + + # Ignore complexity in test files + - path: _test\.go + text: "cognitive complexity|cyclomatic complexity" + +output: + format: colored-line-number + print-issued-lines: true + print-linter-name: true diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index d46e4fe..82ffb8d 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -215,6 +215,40 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com return b } +func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { + if len(apply) == 0 { + return sq + } + + // Wrap the incoming *bun.SelectQuery in our adapter + wrapper := &BunSelectQuery{ + query: sq, + db: b.db, + } + + // Start with the interface value (not pointer) + current := common.SelectQuery(wrapper) + + // Apply each function in sequence + for _, fn := range apply { + if fn != nil { + // Pass ¤t (pointer to interface variable), fn modifies and returns new interface value + modified := fn(current) + current = modified + } + } + + // Extract the final *bun.SelectQuery + if finalBun, ok := current.(*BunSelectQuery); ok { + return finalBun.query + } + + return sq // fallback + }) + return b +} + func (b *BunSelectQuery) Order(order string) common.SelectQuery { b.query = b.query.Order(order) return b diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 6c31fb2..5241ead 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -197,6 +197,36 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co return g } +func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB { + if len(apply) == 0 { + return db + } + + wrapper := &GormSelectQuery{ + db: g.db, + } + + current := common.SelectQuery(wrapper) + + for _, fn := range apply { + if fn != nil { + + modified := fn(current) + current = modified + } + } + + if finalBun, ok := current.(*GormSelectQuery); ok { + return finalBun.db + } + + return db // fallback + }) + + return g +} + func (g *GormSelectQuery) Order(order string) common.SelectQuery { g.db = g.db.Order(order) return g diff --git a/pkg/common/interfaces.go b/pkg/common/interfaces.go index 8efc027..5d85dee 100644 --- a/pkg/common/interfaces.go +++ b/pkg/common/interfaces.go @@ -32,6 +32,7 @@ type SelectQuery interface { Join(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery Preload(relation string, conditions ...interface{}) SelectQuery + PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery Order(order string) SelectQuery Limit(n int) SelectQuery Offset(n int) SelectQuery diff --git a/pkg/common/recursive_crud.go b/pkg/common/recursive_crud.go index d085cf1..62abacf 100644 --- a/pkg/common/recursive_crud.go +++ b/pkg/common/recursive_crud.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // CRUDRequestProvider interface for models that provide CRUD request strings @@ -248,7 +249,7 @@ func (p *NestedCUDProcessor) processUpdate( logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data) - query := p.db.NewUpdate().Table(tableName).SetMap(data).Where("id = ?", id) + query := p.db.NewUpdate().Table(tableName).SetMap(data).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id) result, err := query.Exec(ctx) if err != nil { @@ -268,7 +269,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string logger.Debug("Deleting from %s with ID %v", tableName, id) - query := p.db.NewDelete().Table(tableName).Where("id = ?", id) + query := p.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id) result, err := query.Exec(ctx) if err != nil { diff --git a/pkg/common/types.go b/pkg/common/types.go index 3a9eaaa..cc62b0b 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -35,7 +35,9 @@ type PreloadOption struct { Relation string `json:"relation"` Columns []string `json:"columns"` OmitColumns []string `json:"omit_columns"` + Sort []SortOption `json:"sort"` Filters []FilterOption `json:"filters"` + Where string `json:"where"` Limit *int `json:"limit"` Offset *int `json:"offset"` Updatable *bool `json:"updateable"` // if true, the relation can be updated diff --git a/pkg/common/validation.go b/pkg/common/validation.go index 13bbe10..dbf777b 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -270,3 +270,11 @@ func (v *ColumnValidator) GetValidColumns() []string { } return columns } + +func QuoteIdent(qualifier string) string { + return `"` + strings.ReplaceAll(qualifier, `"`, `""`) + `"` +} + +func QuoteLiteral(value string) string { + return `'` + strings.ReplaceAll(value, `'`, `''`) + `'` +} diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 4a28e60..6fd8d1c 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -4,15 +4,31 @@ import ( "reflect" "strings" - "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/modelregistry" ) +type PrimaryKeyNameProvider interface { + GetIDName() string +} + // GetPrimaryKeyName extracts the primary key column name from a model // It first checks if the model implements PrimaryKeyNameProvider (GetIDName method) // Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag func GetPrimaryKeyName(model any) string { + if reflect.TypeOf(model) == nil { + return "" + } + //If we are given a string model name, look up the model + if reflect.TypeOf(model).Kind() == reflect.String { + name := model.(string) + m, err := modelregistry.GetModelByName(name) + if err == nil { + model = m + } + } + // Check if model implements PrimaryKeyNameProvider - if provider, ok := model.(common.PrimaryKeyNameProvider); ok { + if provider, ok := model.(PrimaryKeyNameProvider); ok { return provider.GetIDName() } @@ -22,7 +38,11 @@ func GetPrimaryKeyName(model any) string { } // Fall back to GORM tag - return getPrimaryKeyFromReflection(model, "gorm") + if pkName := getPrimaryKeyFromReflection(model, "gorm"); pkName != "" { + return pkName + } + + return "" } // GetModelColumns extracts all column names from a model using reflection diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 00a06b8..2a7547c 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -11,6 +11,7 @@ import ( "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // Handler handles API requests using database and model abstractions @@ -249,7 +250,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st logger.Debug("Querying single record with ID: %s", id) // For single record, create a new pointer to the struct type singleResult := reflect.New(modelType).Interface() - query = query.Where("id = ?", id) + + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id) if err := query.Scan(ctx, singleResult); err != nil { logger.Error("Error querying record: %v", err) h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err) @@ -521,15 +523,15 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url // Apply conditions if urlID != "" { logger.Debug("Updating by URL ID: %s", urlID) - query = query.Where("id = ?", urlID) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), urlID) } else if reqID != nil { switch id := reqID.(type) { case string: logger.Debug("Updating by request ID: %s", id) - query = query.Where("id = ?", id) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) case []string: logger.Debug("Updating by multiple IDs: %v", id) - query = query.Where("id IN (?)", id) + query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), id) } } @@ -593,7 +595,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range updates { if itemID, ok := item["id"]; ok { - txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where("id = ?", itemID) + + txQuery := tx.NewUpdate().Table(tableName).SetMap(item).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } @@ -659,7 +662,8 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url for _, item := range updates { if itemMap, ok := item.(map[string]interface{}); ok { if itemID, ok := itemMap["id"]; ok { - txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where("id = ?", itemID) + + txQuery := tx.NewUpdate().Table(tableName).SetMap(itemMap).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID) if _, err := txQuery.Exec(ctx); err != nil { return err } @@ -706,7 +710,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id logger.Info("Batch delete with %d IDs ([]string)", len(v)) err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, itemID := range v { - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) if _, err := query.Exec(ctx); err != nil { return fmt.Errorf("failed to delete record %s: %w", itemID, err) } @@ -745,7 +750,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id continue // Skip items without ID } - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) result, err := query.Exec(ctx) if err != nil { return fmt.Errorf("failed to delete record %v: %w", itemID, err) @@ -770,7 +775,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id err := h.db.RunInTransaction(ctx, func(tx common.Database) error { for _, item := range v { if itemID, ok := item["id"]; ok && itemID != nil { - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) result, err := query.Exec(ctx) if err != nil { return fmt.Errorf("failed to delete record %v: %w", itemID, err) @@ -804,7 +809,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id return } - query := h.db.NewDelete().Table(tableName).Where("id = ?", id) + query := h.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id) result, err := query.Exec(ctx) if err != nil { @@ -1128,7 +1133,54 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre // For now, we'll preload without conditions // TODO: Implement column selection and filtering for preloads // This requires a more sophisticated approach with callbacks or query builders - query = query.Preload(relationFieldName) + // Apply preloading + + logger.Debug("Applying preload: %s", preload.Relation) + query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { + if len(preload.OmitColumns) > 0 { + allCols := reflection.GetModelColumns(model) + // Remove omitted columns + preload.Columns = []string{} + for _, col := range allCols { + addCols := true + for _, omitCol := range preload.OmitColumns { + if col == omitCol { + addCols = false + break + } + } + if addCols { + preload.Columns = append(preload.Columns, col) + } + } + } + + if len(preload.Columns) > 0 { + sq = sq.Column(preload.Columns...) + } + + if len(preload.Filters) > 0 { + for _, filter := range preload.Filters { + sq = h.applyFilter(sq, filter) + } + } + if len(preload.Sort) > 0 { + for _, sort := range preload.Sort { + sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) + } + } + + if len(preload.Where) > 0 { + sq = sq.Where(preload.Where) + } + + if preload.Limit != nil && *preload.Limit > 0 { + sq = sq.Limit(*preload.Limit) + } + + return sq + }) + logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 3e0ec1e..3bca219 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -258,6 +258,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st for colName, colExpr := range options.ComputedQL { logger.Debug("Applying computed column: %s", colName) query = query.ColumnExpr("(?) AS "+colName, colExpr) + for colIndex := range options.Columns { + if options.Columns[colIndex] == colName { + //Remove the computed column from the selected columns to avoid duplication + options.Columns = append(options.Columns[:colIndex], options.Columns[colIndex+1:]...) + break + } + } } } @@ -265,6 +272,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st for _, cu := range options.ComputedColumns { logger.Debug("Applying computed column: %s", cu.Name) query = query.ColumnExpr("(?) AS "+cu.Name, cu.Expression) + for colIndex := range options.Columns { + if options.Columns[colIndex] == cu.Name { + //Remove the computed column from the selected columns to avoid duplication + options.Columns = append(options.Columns[:colIndex], options.Columns[colIndex+1:]...) + break + } + } } } @@ -274,18 +288,91 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st query = query.Column(options.Columns...) } + // Apply expand (Just expand to Preload for now) + for _, expand := range options.Expand { + logger.Debug("Applying expand: %s", expand.Relation) + sorts := make([]common.SortOption, 0) + for _, s := range strings.Split(expand.Sort, ",") { + dir := "ASC" + if strings.HasPrefix(s, "-") || strings.HasSuffix(strings.ToUpper(s), " DESC") { + dir = "DESC" + s = strings.TrimPrefix(s, "-") + s = strings.TrimSuffix(strings.ToLower(s), " desc") + } + sorts = append(sorts, common.SortOption{ + Column: s, Direction: dir, + }) + } + // Note: Expand would require JOIN implementation + // For now, we'll use Preload as a fallback + //query = query.Preload(expand.Relation) + if options.Preload == nil { + options.Preload = make([]common.PreloadOption, 0) + } + skip := false + for _, existing := range options.Preload { + if existing.Relation == expand.Relation { + skip = true + continue + } + } + if !skip { + options.Preload = append(options.Preload, common.PreloadOption{ + Relation: expand.Relation, + Columns: expand.Columns, + Sort: sorts, + Where: expand.Where, + }) + } + } + // Apply preloading for _, preload := range options.Preload { logger.Debug("Applying preload: %s", preload.Relation) - query = query.Preload(preload.Relation) - } + query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { + if len(preload.OmitColumns) > 0 { + allCols := reflection.GetModelColumns(model) + // Remove omitted columns + preload.Columns = []string{} + for _, col := range allCols { + addCols := true + for _, omitCol := range preload.OmitColumns { + if col == omitCol { + addCols = false + break + } + } + if addCols { + preload.Columns = append(preload.Columns, col) + } + } + } - // Apply expand (LEFT JOIN) - for _, expand := range options.Expand { - logger.Debug("Applying expand: %s", expand.Relation) - // Note: Expand would require JOIN implementation - // For now, we'll use Preload as a fallback - query = query.Preload(expand.Relation) + if len(preload.Columns) > 0 { + sq = sq.Column(preload.Columns...) + } + + if len(preload.Filters) > 0 { + for _, filter := range preload.Filters { + sq = h.applyFilter(sq, filter, "", false, "AND") + } + } + if len(preload.Sort) > 0 { + for _, sort := range preload.Sort { + sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction)) + } + } + + if len(preload.Where) > 0 { + sq = sq.Where(preload.Where) + } + + if preload.Limit != nil && *preload.Limit > 0 { + sq = sq.Limit(*preload.Limit) + } + + return sq + }) } // Apply DISTINCT if requested @@ -326,8 +413,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // If ID is provided, filter by ID if id != "" { - logger.Debug("Filtering by ID: %s", id) - query = query.Where("id = ?", id) + pkName := reflection.GetPrimaryKeyName(model) + logger.Debug("Filtering by ID=%s: %s", pkName, id) + + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) } // Apply sorting @@ -794,13 +883,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id } query := h.db.NewUpdate().Table(tableName).SetMap(dataMap) - + pkName := reflection.GetPrimaryKeyName(model) // Apply ID filter switch { case id != "": - query = query.Where("id = ?", id) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id) case idPtr != nil: - query = query.Where("id = ?", *idPtr) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *idPtr) default: h.sendError(w, http.StatusBadRequest, "missing_id", "ID is required for update", nil) return @@ -883,7 +972,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id continue } - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) result, err := query.Exec(ctx) if err != nil { @@ -950,7 +1039,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id continue } - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) result, err := query.Exec(ctx) if err != nil { return fmt.Errorf("failed to delete record %v: %w", itemID, err) @@ -1001,7 +1090,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id continue } - query := tx.NewDelete().Table(tableName).Where("id = ?", itemID) + query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), itemID) result, err := query.Exec(ctx) if err != nil { return fmt.Errorf("failed to delete record %v: %w", itemID, err) @@ -1061,7 +1150,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id return } - query = query.Where("id = ?", id) + query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id) // Execute BeforeScan hooks - pass query chain so hooks can modify it hookCtx.Query = query