From 7e76977dcc581c59b8be9cd8d84de8b67c89fe2e Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 21 Nov 2025 10:17:20 +0200 Subject: [PATCH] Lots of refactoring, Fixes to preloads --- .golangci.json | 4 +- pkg/common/adapters/database/bun.go | 7 +- pkg/common/sql_helpers.go | 111 +++++++++ pkg/common/sql_types.go | 23 +- pkg/common/types.go | 2 +- pkg/common/validation.go | 20 +- pkg/common/validation_json_test.go | 6 +- pkg/reflection/generic_model.go | 3 +- pkg/reflection/model_utils.go | 303 +++++++++++++++++++++++- pkg/resolvespec/handler.go | 10 +- pkg/restheadspec/handler.go | 108 +++++---- pkg/restheadspec/headers.go | 352 +++++----------------------- 12 files changed, 563 insertions(+), 386 deletions(-) diff --git a/.golangci.json b/.golangci.json index 28b34d9..5b41664 100644 --- a/.golangci.json +++ b/.golangci.json @@ -86,7 +86,6 @@ "emptyFallthrough", "equalFold", "flagName", - "ifElseChain", "indexAlloc", "initClause", "methodExprCall", @@ -106,6 +105,9 @@ "unnecessaryBlock", "weakCond", "yodaStyleExpr" + ], + "disabled-checks": [ + "ifElseChain" ] }, "revive": { diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index c0db6ab..52b7526 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -237,7 +237,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { defer func() { if r := recover(); r != nil { - logger.HandlePanic("BunSelectQuery.PreloadRelation", r) + err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r) + if err != nil { + return + } } }() if len(apply) == 0 { @@ -401,7 +404,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error err = logger.HandlePanic("BunInsertQuery.Exec", r) } }() - if b.values != nil && len(b.values) > 0 { + if len(b.values) > 0 { if !b.hasModel { // If no model was set, use the values map as the model // Bun can insert map[string]interface{} directly diff --git a/pkg/common/sql_helpers.go b/pkg/common/sql_helpers.go index 04aa3f2..51e6a5c 100644 --- a/pkg/common/sql_helpers.go +++ b/pkg/common/sql_helpers.go @@ -96,6 +96,117 @@ func IsSQLExpression(cond string) bool { return false } +// IsTrivialCondition checks if a condition is trivial and always evaluates to true +// These conditions should be removed from WHERE clauses as they have no filtering effect +func IsTrivialCondition(cond string) bool { + cond = strings.TrimSpace(cond) + lowerCond := strings.ToLower(cond) + + // Conditions that always evaluate to true + trivialConditions := []string{ + "1=1", "1 = 1", "1= 1", "1 =1", + "true", "true = true", "true=true", "true= true", "true =true", + "0=0", "0 = 0", "0= 0", "0 =0", + } + + for _, trivial := range trivialConditions { + if lowerCond == trivial { + return true + } + } + + return false +} + +// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns +// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL +// +// Parameters: +// - where: The WHERE clause string to sanitize +// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing) +// +// Returns: +// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed +// - An empty string if all conditions were trivial or the input was empty +func SanitizeWhereClause(where string, tableName string) string { + if where == "" { + return "" + } + + where = strings.TrimSpace(where) + + // Split by AND to handle multiple conditions + conditions := splitByAND(where) + + validConditions := make([]string, 0, len(conditions)) + + for _, cond := range conditions { + cond = strings.TrimSpace(cond) + if cond == "" { + continue + } + + // Skip trivial conditions that always evaluate to true + if IsTrivialCondition(cond) { + logger.Debug("Removing trivial condition: '%s'", cond) + continue + } + + // If tableName is provided and the condition doesn't already have a table prefix, + // attempt to add it + if tableName != "" && !hasTablePrefix(cond) { + // Check if this is a SQL expression/literal that shouldn't be prefixed + if !IsSQLExpression(strings.ToLower(cond)) { + // Extract the column name and prefix it + columnName := ExtractColumnName(cond) + if columnName != "" { + cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1) + logger.Debug("Prefixed column in condition: '%s'", cond) + } + } + } + + validConditions = append(validConditions, cond) + } + + if len(validConditions) == 0 { + return "" + } + + result := strings.Join(validConditions, " AND ") + + if result != where { + logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result) + } + + return result +} + +// splitByAND splits a WHERE clause by AND operators (case-insensitive) +// This is a simple split that doesn't handle nested parentheses or complex expressions +func splitByAND(where string) []string { + // First try uppercase AND + conditions := strings.Split(where, " AND ") + + // If we didn't split on uppercase, try lowercase + if len(conditions) == 1 { + conditions = strings.Split(where, " and ") + } + + // If we still didn't split, try mixed case + if len(conditions) == 1 { + conditions = strings.Split(where, " And ") + } + + return conditions +} + +// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot) +func hasTablePrefix(cond string) bool { + // Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\"" + return strings.Contains(cond, ".") +} + // ExtractColumnName extracts the column name from a WHERE condition // For example: "status = 'active'" returns "status" func ExtractColumnName(cond string) string { diff --git a/pkg/common/sql_types.go b/pkg/common/sql_types.go index 97adf42..d07e966 100644 --- a/pkg/common/sql_types.go +++ b/pkg/common/sql_types.go @@ -238,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error { var err error if b == nil { - t = &SqlTimeStamp{} + return nil } s := strings.Trim(strings.Trim(string(b), " "), "\"") if s == "null" || s == "" || s == "0" || s == "0001-01-01T00:00:00" || s == "0001-01-01" { - t = &SqlTimeStamp{} + return nil } @@ -293,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error { // String - Override String format of time func (t SqlTimeStamp) String() string { - return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05")) + return time.Time(t).Format("2006-01-02T15:04:05") } // GetTime - Returns Time @@ -308,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) { // Format - Formats the time func (t SqlTimeStamp) Format(layout string) string { - return fmt.Sprintf("%s", time.Time(t).Format(layout)) + return time.Time(t).Format(layout) } func SqlTimeStampNow() SqlTimeStamp { @@ -420,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error { if s == "null" || s == "" || s == "0" || strings.HasPrefix(s, "0001-01-01T00:00:00") || s == "0001-01-01" { - t = &SqlDate{} return nil } @@ -434,7 +433,7 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error { // MarshalJSON - Override JSON format of time func (t SqlDate) MarshalJSON() ([]byte, error) { - tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339 + tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339 if strings.HasPrefix(tmstr, "0001-01-01") { return []byte("null"), nil } @@ -482,7 +481,7 @@ func (t SqlDate) Int64() int64 { // String - Override String format of time func (t SqlDate) String() string { - tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339 + tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339 if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") { return "0" } @@ -517,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error { *t = SqlTime{} return nil } - tx := time.Time{} - tx, err = tryParseDT(s) + + tx, err := tryParseDT(s) *t = SqlTime(tx) return err @@ -642,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) { func (n *SqlJSONB) UnmarshalJSON(b []byte) error { s := strings.Trim(strings.Trim(string(b), " "), "\"") - invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "[")) + invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "[")) if invalid { - s = "" return nil } @@ -661,7 +659,7 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) { var obj interface{} err := json.Unmarshal(n, &obj) if err != nil { - //fmt.Printf("Invalid JSON %v", err) + // fmt.Printf("Invalid JSON %v", err) return []byte("null"), nil } @@ -725,7 +723,6 @@ func (n *SqlUUID) UnmarshalJSON(b []byte) error { s := strings.Trim(strings.Trim(string(b), " "), "\"") invalid := (s == "null" || s == "" || len(s) < 30) if invalid { - s = "" return nil } *n = SqlUUID(sql.NullString{String: s, Valid: !invalid}) diff --git a/pkg/common/types.go b/pkg/common/types.go index 330bcfa..10c0d39 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -40,7 +40,7 @@ type PreloadOption struct { Where string `json:"where"` Limit *int `json:"limit"` Offset *int `json:"offset"` - Updatable *bool `json:"updateable"` // if true, the relation can be updated + Updatable *bool `json:"updateable"` // if true, the relation can be updated ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels } diff --git a/pkg/common/validation.go b/pkg/common/validation.go index ba97bb5..c177471 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // ColumnValidator validates column names against a model's fields @@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string { return strings.ToLower(field.Name) } -// extractSourceColumn extracts the base column name from PostgreSQL JSON operators -// Examples: -// - "columna->>'val'" returns "columna" -// - "columna->'key'" returns "columna" -// - "columna" returns "columna" -// - "table.columna->>'val'" returns "table.columna" -func extractSourceColumn(colName string) string { - // Check for PostgreSQL JSON operators: -> and ->> - if idx := strings.Index(colName, "->>"); idx != -1 { - return strings.TrimSpace(colName[:idx]) - } - if idx := strings.Index(colName, "->"); idx != -1 { - return strings.TrimSpace(colName[:idx]) - } - return colName -} - // ValidateColumn validates a single column name // Returns nil if valid, error if invalid // Columns prefixed with "cql" (case insensitive) are always valid @@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error { } // Extract source column name (remove JSON operators like ->> or ->) - sourceColumn := extractSourceColumn(column) + sourceColumn := reflection.ExtractSourceColumn(column) // Check if column exists in model if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists { diff --git a/pkg/common/validation_json_test.go b/pkg/common/validation_json_test.go index 1c6273a..79a0f36 100644 --- a/pkg/common/validation_json_test.go +++ b/pkg/common/validation_json_test.go @@ -2,6 +2,8 @@ package common import ( "testing" + + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) func TestExtractSourceColumn(t *testing.T) { @@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - result := extractSourceColumn(tc.input) + result := reflection.ExtractSourceColumn(tc.input) if result != tc.expected { - t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected) + t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected) } }) } diff --git a/pkg/reflection/generic_model.go b/pkg/reflection/generic_model.go index cb33107..333092d 100644 --- a/pkg/reflection/generic_model.go +++ b/pkg/reflection/generic_model.go @@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail { } }() - var lst []ModelFieldDetail - lst = make([]ModelFieldDetail, 0) + lst := make([]ModelFieldDetail, 0) if !record.IsValid() { return lst diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 9409e2f..7179687 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -1,7 +1,9 @@ package reflection import ( + "fmt" "reflect" + "strconv" "strings" "github.com/bitechdev/ResolveSpec/pkg/modelregistry" @@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any { } // Check if field name matches - if strings.ToLower(field.Name) == name && fieldValue.CanInterface() { + if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() { return fieldValue.Interface() } } @@ -409,8 +411,8 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe if bunTag != "" { // Skip if it's a bun relation (rel:, join:, or m2m:) if strings.Contains(bunTag, "rel:") || - strings.Contains(bunTag, "join:") || - strings.Contains(bunTag, "m2m:") { + strings.Contains(bunTag, "join:") || + strings.Contains(bunTag, "m2m:") { continue } } @@ -419,9 +421,9 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe if gormTag != "" { // Skip if it has gorm relationship tags if strings.Contains(gormTag, "foreignKey:") || - strings.Contains(gormTag, "references:") || - strings.Contains(gormTag, "many2many:") || - strings.Contains(gormTag, "constraint:") { + strings.Contains(gormTag, "references:") || + strings.Contains(gormTag, "many2many:") || + strings.Contains(gormTag, "constraint:") { continue } } @@ -472,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool { // isColumnWritableInType recursively searches for a column and checks if it's writable // Returns (found, writable) where found indicates if the column was found -func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) { +func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) { for i := 0; i < typ.NumField(); i++ { field := typ.Field(i) @@ -561,3 +563,290 @@ func isGormFieldReadOnly(tag string) bool { } return false } + +// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators +// Examples: +// - "columna->>'val'" returns "columna" +// - "columna->'key'" returns "columna" +// - "columna" returns "columna" +// - "table.columna->>'val'" returns "table.columna" +func ExtractSourceColumn(colName string) string { + // Check for PostgreSQL JSON operators: -> and ->> + if idx := strings.Index(colName, "->>"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + if idx := strings.Index(colName, "->"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + return colName +} + +// ToSnakeCase converts a string from CamelCase to snake_case +func ToSnakeCase(s string) string { + var result strings.Builder + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result.WriteRune('_') + } + result.WriteRune(r) + } + return strings.ToLower(result.String()) +} + +// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model +func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind { + if model == nil { + return reflect.Invalid + } + + // Extract the source column name (remove JSON operators like ->> or ->) + sourceColName := ExtractSourceColumn(colName) + + modelType := reflect.TypeOf(model) + // Dereference pointer if needed + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + // Ensure it's a struct + if modelType.Kind() != reflect.Struct { + return reflect.Invalid + } + + // Find the field by JSON tag or field name + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Check JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" { + // Parse JSON tag (format: "name,omitempty") + parts := strings.Split(jsonTag, ",") + if parts[0] == sourceColName { + return field.Type.Kind() + } + } + + // Check field name (case-insensitive) + if strings.EqualFold(field.Name, sourceColName) { + return field.Type.Kind() + } + + // Check snake_case conversion + snakeCaseName := ToSnakeCase(field.Name) + if snakeCaseName == sourceColName { + return field.Type.Kind() + } + } + + return reflect.Invalid +} + +// IsNumericType checks if a reflect.Kind is a numeric type +func IsNumericType(kind reflect.Kind) bool { + return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || + kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint || + kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || + kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64 +} + +// IsStringType checks if a reflect.Kind is a string type +func IsStringType(kind reflect.Kind) bool { + return kind == reflect.String +} + +// IsNumericValue checks if a string value can be parsed as a number +func IsNumericValue(value string) bool { + value = strings.TrimSpace(value) + _, err := strconv.ParseFloat(value, 64) + return err == nil +} + +// ConvertToNumericType converts a string value to the appropriate numeric type +func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) { + value = strings.TrimSpace(value) + + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // Parse as integer + bitSize := 64 + switch kind { + case reflect.Int8: + bitSize = 8 + case reflect.Int16: + bitSize = 16 + case reflect.Int32: + bitSize = 32 + } + + intVal, err := strconv.ParseInt(value, 10, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid integer value: %w", err) + } + + // Return the appropriate type + switch kind { + case reflect.Int: + return int(intVal), nil + case reflect.Int8: + return int8(intVal), nil + case reflect.Int16: + return int16(intVal), nil + case reflect.Int32: + return int32(intVal), nil + case reflect.Int64: + return intVal, nil + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // Parse as unsigned integer + bitSize := 64 + switch kind { + case reflect.Uint8: + bitSize = 8 + case reflect.Uint16: + bitSize = 16 + case reflect.Uint32: + bitSize = 32 + } + + uintVal, err := strconv.ParseUint(value, 10, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid unsigned integer value: %w", err) + } + + // Return the appropriate type + switch kind { + case reflect.Uint: + return uint(uintVal), nil + case reflect.Uint8: + return uint8(uintVal), nil + case reflect.Uint16: + return uint16(uintVal), nil + case reflect.Uint32: + return uint32(uintVal), nil + case reflect.Uint64: + return uintVal, nil + } + + case reflect.Float32, reflect.Float64: + // Parse as float + bitSize := 64 + if kind == reflect.Float32 { + bitSize = 32 + } + + floatVal, err := strconv.ParseFloat(value, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid float value: %w", err) + } + + if kind == reflect.Float32 { + return float32(floatVal), nil + } + return floatVal, nil + } + + return nil, fmt.Errorf("unsupported numeric type: %v", kind) +} + +// GetRelationModel gets the model type for a relation field +// It searches for the field by name in the following order (case-insensitive): +// 1. Actual field name +// 2. Bun tag name (if exists) +// 3. Gorm tag name (if exists) +// 4. JSON tag name (if exists) +func GetRelationModel(model interface{}, fieldName string) interface{} { + if model == nil || fieldName == "" { + return nil + } + + modelType := reflect.TypeOf(model) + if modelType == nil { + return nil + } + + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType == nil || modelType.Kind() != reflect.Struct { + return nil + } + + // Find the field by checking in priority order (case-insensitive) + var field *reflect.StructField + normalizedFieldName := strings.ToLower(fieldName) + + for i := 0; i < modelType.NumField(); i++ { + f := modelType.Field(i) + + // 1. Check actual field name (case-insensitive) + if strings.EqualFold(f.Name, fieldName) { + field = &f + break + } + + // 2. Check bun tag name + bunTag := f.Tag.Get("bun") + if bunTag != "" { + bunColName := ExtractColumnFromBunTag(bunTag) + if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) { + field = &f + break + } + } + + // 3. Check gorm tag name + gormTag := f.Tag.Get("gorm") + if gormTag != "" { + gormColName := ExtractColumnFromGormTag(gormTag) + if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) { + field = &f + break + } + } + + // 4. Check JSON tag name + jsonTag := f.Tag.Get("json") + if jsonTag != "" { + parts := strings.Split(jsonTag, ",") + if len(parts) > 0 && parts[0] != "" && parts[0] != "-" { + if strings.EqualFold(parts[0], normalizedFieldName) { + field = &f + break + } + } + } + } + + if field == nil { + return nil + } + + // Get the target type + targetType := field.Type + if targetType == nil { + return nil + } + + if targetType.Kind() == reflect.Slice { + targetType = targetType.Elem() + if targetType == nil { + return nil + } + } + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + if targetType == nil { + return nil + } + } + + if targetType.Kind() != reflect.Struct { + return nil + } + + // Create a zero value of the target type + return reflect.New(targetType).Elem().Interface() +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index ccc1cb5..65d20c8 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1149,6 +1149,11 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre logger.Debug("Applying preload: %s", relationFieldName) query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { + if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) { + preload.Columns = reflection.GetSQLModelColumns(model) + } + + // Handle column selection and omission if len(preload.OmitColumns) > 0 { allCols := reflection.GetSQLModelColumns(model) // Remove omitted columns @@ -1204,7 +1209,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre } if len(preload.Where) > 0 { - sq = sq.Where(preload.Where) + sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) + if len(sanitizedWhere) > 0 { + sq = sq.Where(sanitizedWhere) + } } if preload.Limit != nil && *preload.Limit > 0 { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index f69acc2..6b25834 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -391,13 +391,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply custom SQL WHERE clause (AND condition) if options.CustomSQLWhere != "" { logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) - query = query.Where(options.CustomSQLWhere) + // Sanitize without auto-prefixing since custom SQL may reference multiple tables + sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "") + if sanitizedWhere != "" { + query = query.Where(sanitizedWhere) + } } // Apply custom SQL WHERE clause (OR condition) if options.CustomSQLOr != "" { logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) - query = query.WhereOr(options.CustomSQLOr) + // Sanitize without auto-prefixing since custom SQL may reference multiple tables + sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "") + if sanitizedOr != "" { + query = query.WhereOr(sanitizedOr) + } } // If ID is provided, filter by ID @@ -473,7 +481,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply cursor filter to query if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) - query = query.Where(cursorFilter) + sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "") + if sanitizedCursor != "" { + query = query.Where(sanitizedCursor) + } } } @@ -552,56 +563,58 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply the preload query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery { // Get the related model for column operations - relatedModel := h.getRelationModel(model, preload.Relation) + relationParts := strings.Split(preload.Relation, ",") + relatedModel := reflection.GetRelationModel(model, relationParts[0]) if relatedModel == nil { logger.Warn("Could not get related model for preload: %s", preload.Relation) - relatedModel = model // fallback to parent model - } + // relatedModel = model // fallback to parent model + } else { - // If we have computed columns but no explicit columns, populate with all model columns first - // since computed columns are additions - if len(preload.Columns) == 0 && len(preload.ComputedQL) > 0 && relatedModel != nil { - logger.Debug("Populating preload columns with all model columns since computed columns are additions") - preload.Columns = reflection.GetSQLModelColumns(relatedModel) - } + // If we have computed columns but no explicit columns, populate with all model columns first + // since computed columns are additions + if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) { + logger.Debug("Populating preload columns with all model columns since computed columns are additions") + preload.Columns = reflection.GetSQLModelColumns(relatedModel) + } - // Apply ComputedQL fields if any - if len(preload.ComputedQL) > 0 { - for colName, colExpr := range preload.ComputedQL { - logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName) - sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName)) - // Remove the computed column from selected columns to avoid duplication - for colIndex := range preload.Columns { - if preload.Columns[colIndex] == colName { - preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...) - break + // Apply ComputedQL fields if any + if len(preload.ComputedQL) > 0 { + for colName, colExpr := range preload.ComputedQL { + logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName) + sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName)) + // Remove the computed column from selected columns to avoid duplication + for colIndex := range preload.Columns { + if preload.Columns[colIndex] == colName { + preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...) + break + } } } } - } - // Handle OmitColumns - if len(preload.OmitColumns) > 0 && relatedModel != nil { - allCols := reflection.GetModelColumns(relatedModel) - // Remove omitted columns - preload.Columns = []string{} - for _, col := range allCols { - addCols := true - for _, omitCol := range preload.OmitColumns { - if col == omitCol { - addCols = false - break + // Handle OmitColumns + if len(preload.OmitColumns) > 0 { + allCols := preload.Columns + // Remove omitted columns + preload.Columns = []string{} + for _, col := range allCols { + addCols := true + for _, omitCol := range preload.OmitColumns { + if col == omitCol { + addCols = false + break + } + } + if addCols { + preload.Columns = append(preload.Columns, col) } } - if addCols { - preload.Columns = append(preload.Columns, col) - } } - } - // Apply column selection - if len(preload.Columns) > 0 { - sq = sq.Column(preload.Columns...) + // Apply column selection + if len(preload.Columns) > 0 { + sq = sq.Column(preload.Columns...) + } } // Apply filters @@ -620,7 +633,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply WHERE clause if len(preload.Where) > 0 { - sq = sq.Where(preload.Where) + sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) + if len(sanitizedWhere) > 0 { + sq = sq.Where(sanitizedWhere) + } } // Apply limit @@ -628,6 +644,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co sq = sq.Limit(*preload.Limit) } + if preload.Offset != nil && *preload.Offset > 0 { + sq = sq.Offset(*preload.Offset) + } + return sq }) @@ -1312,7 +1332,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} { func (h *Handler) extractNestedRelations( data map[string]interface{}, model interface{}, -) (map[string]interface{}, map[string]interface{}, error) { +) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) { // Get model type for reflection modelType := reflect.TypeOf(model) for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { @@ -1741,7 +1761,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac // Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged func (h *Handler) normalizeResultArray(data interface{}) interface{} { if data == nil { - return data + return nil } // Use reflection to check if data is a slice or array diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 2930189..c08bc32 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -10,6 +10,7 @@ import ( "github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // ExtendedRequestOptions extends common.RequestOptions with additional features @@ -122,78 +123,77 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E // Merge headers and query parameters - query parameters take precedence // This allows the same parameters to be specified in either headers or query string + // Normalize keys to lowercase to ensure query params properly override headers combinedParams := make(map[string]string) for key, value := range headers { - combinedParams[key] = value + combinedParams[strings.ToLower(key)] = value } for key, value := range queryParams { - combinedParams[key] = value + combinedParams[strings.ToLower(key)] = value } // Process each parameter (from both headers and query params) + // Note: keys are already normalized to lowercase in combinedParams for key, value := range combinedParams { - // Normalize parameter key to lowercase for consistent matching - normalizedKey := strings.ToLower(key) - // Decode value if it's base64 encoded decodedValue := decodeHeaderValue(value) // Parse based on parameter prefix/name switch { // Field Selection - case strings.HasPrefix(normalizedKey, "x-select-fields"): + case strings.HasPrefix(key, "x-select-fields"): h.parseSelectFields(&options, decodedValue) - case strings.HasPrefix(normalizedKey, "x-not-select-fields"): + case strings.HasPrefix(key, "x-not-select-fields"): h.parseNotSelectFields(&options, decodedValue) - case strings.HasPrefix(normalizedKey, "x-clean-json"): + case strings.HasPrefix(key, "x-clean-json"): options.CleanJSON = strings.EqualFold(decodedValue, "true") // Filtering & Search - case strings.HasPrefix(normalizedKey, "x-fieldfilter-"): - h.parseFieldFilter(&options, normalizedKey, decodedValue) - case strings.HasPrefix(normalizedKey, "x-searchfilter-"): - h.parseSearchFilter(&options, normalizedKey, decodedValue) - case strings.HasPrefix(normalizedKey, "x-searchop-"): - h.parseSearchOp(&options, normalizedKey, decodedValue, "AND") - case strings.HasPrefix(normalizedKey, "x-searchor-"): - h.parseSearchOp(&options, normalizedKey, decodedValue, "OR") - case strings.HasPrefix(normalizedKey, "x-searchand-"): - h.parseSearchOp(&options, normalizedKey, decodedValue, "AND") - case strings.HasPrefix(normalizedKey, "x-searchcols"): + case strings.HasPrefix(key, "x-fieldfilter-"): + h.parseFieldFilter(&options, key, decodedValue) + case strings.HasPrefix(key, "x-searchfilter-"): + h.parseSearchFilter(&options, key, decodedValue) + case strings.HasPrefix(key, "x-searchop-"): + h.parseSearchOp(&options, key, decodedValue, "AND") + case strings.HasPrefix(key, "x-searchor-"): + h.parseSearchOp(&options, key, decodedValue, "OR") + case strings.HasPrefix(key, "x-searchand-"): + h.parseSearchOp(&options, key, decodedValue, "AND") + case strings.HasPrefix(key, "x-searchcols"): options.SearchColumns = h.parseCommaSeparated(decodedValue) - case strings.HasPrefix(normalizedKey, "x-custom-sql-w"): + case strings.HasPrefix(key, "x-custom-sql-w"): options.CustomSQLWhere = decodedValue - case strings.HasPrefix(normalizedKey, "x-custom-sql-or"): + case strings.HasPrefix(key, "x-custom-sql-or"): options.CustomSQLOr = decodedValue // Joins & Relations - case strings.HasPrefix(normalizedKey, "x-preload"): - if strings.HasSuffix(normalizedKey, "-where") { + case strings.HasPrefix(key, "x-preload"): + if strings.HasSuffix(key, "-where") { continue } whereClaude := combinedParams[fmt.Sprintf("%s-where", key)] h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude)) - case strings.HasPrefix(normalizedKey, "x-expand"): + case strings.HasPrefix(key, "x-expand"): h.parseExpand(&options, decodedValue) - case strings.HasPrefix(normalizedKey, "x-custom-sql-join"): + case strings.HasPrefix(key, "x-custom-sql-join"): // TODO: Implement custom SQL join logger.Debug("Custom SQL join not yet implemented: %s", decodedValue) // Sorting & Pagination - case strings.HasPrefix(normalizedKey, "x-sort"): + case strings.HasPrefix(key, "x-sort"): h.parseSorting(&options, decodedValue) - //Special cases for older clients using sort(a,b,-c) syntax - case strings.HasPrefix(normalizedKey, "sort(") && strings.Contains(normalizedKey, ")"): - sortValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")] + // Special cases for older clients using sort(a,b,-c) syntax + case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"): + sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")] h.parseSorting(&options, sortValue) - case strings.HasPrefix(normalizedKey, "x-limit"): + case strings.HasPrefix(key, "x-limit"): if limit, err := strconv.Atoi(decodedValue); err == nil { options.Limit = &limit } - //Special cases for older clients using limit(n) syntax - case strings.HasPrefix(normalizedKey, "limit(") && strings.Contains(normalizedKey, ")"): - limitValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")] + // Special cases for older clients using limit(n) syntax + case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"): + limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")] limitValueParts := strings.Split(limitValue, ",") if len(limitValueParts) > 1 { @@ -209,42 +209,42 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E } } - case strings.HasPrefix(normalizedKey, "x-offset"): + case strings.HasPrefix(key, "x-offset"): if offset, err := strconv.Atoi(decodedValue); err == nil { options.Offset = &offset } - case strings.HasPrefix(normalizedKey, "x-cursor-forward"): + case strings.HasPrefix(key, "x-cursor-forward"): options.CursorForward = decodedValue - case strings.HasPrefix(normalizedKey, "x-cursor-backward"): + case strings.HasPrefix(key, "x-cursor-backward"): options.CursorBackward = decodedValue // Advanced Features - case strings.HasPrefix(normalizedKey, "x-advsql-"): - colName := strings.TrimPrefix(normalizedKey, "x-advsql-") + case strings.HasPrefix(key, "x-advsql-"): + colName := strings.TrimPrefix(key, "x-advsql-") options.AdvancedSQL[colName] = decodedValue - case strings.HasPrefix(normalizedKey, "x-cql-sel-"): - colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-") + case strings.HasPrefix(key, "x-cql-sel-"): + colName := strings.TrimPrefix(key, "x-cql-sel-") options.ComputedQL[colName] = decodedValue - case strings.HasPrefix(normalizedKey, "x-distinct"): + case strings.HasPrefix(key, "x-distinct"): options.Distinct = strings.EqualFold(decodedValue, "true") - case strings.HasPrefix(normalizedKey, "x-skipcount"): + case strings.HasPrefix(key, "x-skipcount"): options.SkipCount = strings.EqualFold(decodedValue, "true") - case strings.HasPrefix(normalizedKey, "x-skipcache"): + case strings.HasPrefix(key, "x-skipcache"): options.SkipCache = strings.EqualFold(decodedValue, "true") - case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"): + case strings.HasPrefix(key, "x-fetch-rownumber"): options.FetchRowNumber = &decodedValue - case strings.HasPrefix(normalizedKey, "x-pkrow"): + case strings.HasPrefix(key, "x-pkrow"): options.PKRow = &decodedValue // Response Format - case strings.HasPrefix(normalizedKey, "x-simpleapi"): + case strings.HasPrefix(key, "x-simpleapi"): options.ResponseFormat = "simple" - case strings.HasPrefix(normalizedKey, "x-detailapi"): + case strings.HasPrefix(key, "x-detailapi"): options.ResponseFormat = "detail" - case strings.HasPrefix(normalizedKey, "x-syncfusion"): + case strings.HasPrefix(key, "x-syncfusion"): options.ResponseFormat = "syncfusion" - case strings.HasPrefix(normalizedKey, "x-single-record-as-object"): + case strings.HasPrefix(key, "x-single-record-as-object"): // Parse as boolean - "false" disables, "true" enables (default is true) if strings.EqualFold(decodedValue, "false") { options.SingleRecordAsObject = false @@ -253,11 +253,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E } // Transaction Control - case strings.HasPrefix(normalizedKey, "x-transaction-atomic"): + case strings.HasPrefix(key, "x-transaction-atomic"): options.AtomicTransaction = strings.EqualFold(decodedValue, "true") // X-Files - comprehensive JSON configuration - case strings.HasPrefix(normalizedKey, "x-files"): + case strings.HasPrefix(key, "x-files"): h.parseXFiles(&options, decodedValue) } } @@ -720,7 +720,7 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, // Try to get the model type for the next level // This allows nested resolution - if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil { + if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil { currentModel = nextModel } } @@ -744,58 +744,6 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, } } -// getRelationModel gets the model type for a relation field -func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} { - if model == nil || fieldName == "" { - return nil - } - - modelType := reflect.TypeOf(model) - if modelType == nil { - return nil - } - - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - if modelType == nil || modelType.Kind() != reflect.Struct { - return nil - } - - // Find the field - field, found := modelType.FieldByName(fieldName) - if !found { - return nil - } - - // Get the target type - targetType := field.Type - if targetType == nil { - return nil - } - - if targetType.Kind() == reflect.Slice { - targetType = targetType.Elem() - if targetType == nil { - return nil - } - } - if targetType.Kind() == reflect.Ptr { - targetType = targetType.Elem() - if targetType == nil { - return nil - } - } - - if targetType.Kind() != reflect.Struct { - return nil - } - - // Create a zero value of the target type - return reflect.New(targetType).Elem().Interface() -} - // resolveRelationName resolves a relation name or table name to the actual field name in the model // If the input is already a field name, it returns it as-is // If the input is a table name, it looks up the corresponding relation field @@ -983,192 +931,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption } } -// extractSourceColumn extracts the base column name from PostgreSQL JSON operators -// Examples: -// - "columna->>'val'" returns "columna" -// - "columna->'key'" returns "columna" -// - "columna" returns "columna" -// - "table.columna->>'val'" returns "table.columna" -func extractSourceColumn(colName string) string { - // Check for PostgreSQL JSON operators: -> and ->> - if idx := strings.Index(colName, "->>"); idx != -1 { - return strings.TrimSpace(colName[:idx]) - } - if idx := strings.Index(colName, "->"); idx != -1 { - return strings.TrimSpace(colName[:idx]) - } - return colName -} - -// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model -func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind { - if model == nil { - return reflect.Invalid - } - - // Extract the source column name (remove JSON operators like ->> or ->) - sourceColName := extractSourceColumn(colName) - - modelType := reflect.TypeOf(model) - // Dereference pointer if needed - if modelType.Kind() == reflect.Ptr { - modelType = modelType.Elem() - } - - // Ensure it's a struct - if modelType.Kind() != reflect.Struct { - return reflect.Invalid - } - - // Find the field by JSON tag or field name - for i := 0; i < modelType.NumField(); i++ { - field := modelType.Field(i) - - // Check JSON tag - jsonTag := field.Tag.Get("json") - if jsonTag != "" { - // Parse JSON tag (format: "name,omitempty") - parts := strings.Split(jsonTag, ",") - if parts[0] == sourceColName { - return field.Type.Kind() - } - } - - // Check field name (case-insensitive) - if strings.EqualFold(field.Name, sourceColName) { - return field.Type.Kind() - } - - // Check snake_case conversion - snakeCaseName := toSnakeCase(field.Name) - if snakeCaseName == sourceColName { - return field.Type.Kind() - } - } - - return reflect.Invalid -} - -// toSnakeCase converts a string from CamelCase to snake_case -func toSnakeCase(s string) string { - var result strings.Builder - for i, r := range s { - if i > 0 && r >= 'A' && r <= 'Z' { - result.WriteRune('_') - } - result.WriteRune(r) - } - return strings.ToLower(result.String()) -} - -// isNumericType checks if a reflect.Kind is a numeric type -func isNumericType(kind reflect.Kind) bool { - return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || - kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint || - kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || - kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64 -} - -// isStringType checks if a reflect.Kind is a string type -func isStringType(kind reflect.Kind) bool { - return kind == reflect.String -} - -// convertToNumericType converts a string value to the appropriate numeric type -func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) { - value = strings.TrimSpace(value) - - switch kind { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - // Parse as integer - bitSize := 64 - switch kind { - case reflect.Int8: - bitSize = 8 - case reflect.Int16: - bitSize = 16 - case reflect.Int32: - bitSize = 32 - } - - intVal, err := strconv.ParseInt(value, 10, bitSize) - if err != nil { - return nil, fmt.Errorf("invalid integer value: %w", err) - } - - // Return the appropriate type - switch kind { - case reflect.Int: - return int(intVal), nil - case reflect.Int8: - return int8(intVal), nil - case reflect.Int16: - return int16(intVal), nil - case reflect.Int32: - return int32(intVal), nil - case reflect.Int64: - return intVal, nil - } - - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - // Parse as unsigned integer - bitSize := 64 - switch kind { - case reflect.Uint8: - bitSize = 8 - case reflect.Uint16: - bitSize = 16 - case reflect.Uint32: - bitSize = 32 - } - - uintVal, err := strconv.ParseUint(value, 10, bitSize) - if err != nil { - return nil, fmt.Errorf("invalid unsigned integer value: %w", err) - } - - // Return the appropriate type - switch kind { - case reflect.Uint: - return uint(uintVal), nil - case reflect.Uint8: - return uint8(uintVal), nil - case reflect.Uint16: - return uint16(uintVal), nil - case reflect.Uint32: - return uint32(uintVal), nil - case reflect.Uint64: - return uintVal, nil - } - - case reflect.Float32, reflect.Float64: - // Parse as float - bitSize := 64 - if kind == reflect.Float32 { - bitSize = 32 - } - - floatVal, err := strconv.ParseFloat(value, bitSize) - if err != nil { - return nil, fmt.Errorf("invalid float value: %w", err) - } - - if kind == reflect.Float32 { - return float32(floatVal), nil - } - return floatVal, nil - } - - return nil, fmt.Errorf("unsupported numeric type: %v", kind) -} - -// isNumericValue checks if a string value can be parsed as a number -func isNumericValue(value string) bool { - value = strings.TrimSpace(value) - _, err := strconv.ParseFloat(value, 64) - return err == nil -} - // ColumnCastInfo holds information about whether a column needs casting type ColumnCastInfo struct { NeedsCast bool @@ -1182,7 +944,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti return ColumnCastInfo{NeedsCast: false, IsNumericType: false} } - colType := h.getColumnTypeFromModel(model, filter.Column) + colType := reflection.GetColumnTypeFromModel(model, filter.Column) if colType == reflect.Invalid { // Column not found in model, no casting needed logger.Debug("Column %s not found in model, skipping type validation", filter.Column) @@ -1193,18 +955,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti valueIsNumeric := false if strVal, ok := filter.Value.(string); ok { strVal = strings.Trim(strVal, "%") - valueIsNumeric = isNumericValue(strVal) + valueIsNumeric = reflection.IsNumericValue(strVal) } // Adjust based on column type switch { - case isNumericType(colType): + case reflection.IsNumericType(colType): // Column is numeric if valueIsNumeric { // Value is numeric - try to convert it if strVal, ok := filter.Value.(string); ok { strVal = strings.Trim(strVal, "%") - numericVal, err := convertToNumericType(strVal, colType) + numericVal, err := reflection.ConvertToNumericType(strVal, colType) if err != nil { logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column) return ColumnCastInfo{NeedsCast: true, IsNumericType: true} @@ -1219,7 +981,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti return ColumnCastInfo{NeedsCast: true, IsNumericType: true} } - case isStringType(colType): + case reflection.IsStringType(colType): // String columns don't need casting return ColumnCastInfo{NeedsCast: false, IsNumericType: false}