diff --git a/pkg/common/types.go b/pkg/common/types.go index 56e6f60..a5c6d70 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -37,9 +37,10 @@ type PreloadOption struct { } type FilterOption struct { - Column string `json:"column"` - Operator string `json:"operator"` - Value interface{} `json:"value"` + Column string `json:"column"` + Operator string `json:"operator"` + Value interface{} `json:"value"` + LogicOperator string `json:"logic_operator"` // "AND" or "OR" - how this filter combines with previous filters } type SortOption struct { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 9630186..5b39d2a 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -239,10 +239,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // This may need to be handled differently per database adapter } - // Apply filters - for _, filter := range options.Filters { - logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value) - query = h.applyFilter(query, filter) + // Apply filters - validate and adjust for column types first + for i := range options.Filters { + filter := &options.Filters[i] + + // Validate and adjust filter based on column type + castInfo := h.ValidateAndAdjustFilterForColumnType(filter, model) + + // Default to AND if LogicOperator is not set + logicOp := filter.LogicOperator + if logicOp == "" { + logicOp = "AND" + } + + logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp) + query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp) } // Apply custom SQL WHERE clause (AND condition) @@ -491,55 +502,96 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id }, nil) } -func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery { +// qualifyColumnName ensures column name is fully qualified with table name if not already +func (h *Handler) qualifyColumnName(columnName, fullTableName string) string { + // Check if column already has a table/schema prefix (contains a dot) + if strings.Contains(columnName, ".") { + return columnName + } + + // If no table name provided, return column as-is + if fullTableName == "" { + return columnName + } + + // Extract just the table name from "schema.table" format + // Only use the table name part, not the schema + tableOnly := fullTableName + if idx := strings.LastIndex(fullTableName, "."); idx != -1 { + tableOnly = fullTableName[idx+1:] + } + + // Return column qualified with just the table name + return fmt.Sprintf("%s.%s", tableOnly, columnName) +} + +func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery { + // Qualify the column name with table name if not already qualified + qualifiedColumn := h.qualifyColumnName(filter.Column, tableName) + + // Apply casting to text if needed for non-numeric columns or non-numeric values + if needsCast { + qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn) + } + + // Helper function to apply the correct Where method based on logic operator + applyWhere := func(condition string, args ...interface{}) common.SelectQuery { + if logicOp == "OR" { + return query.WhereOr(condition, args...) + } + return query.Where(condition, args...) + } + switch strings.ToLower(filter.Operator) { case "eq", "equals": - return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value) case "neq", "not_equals", "ne": - return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s != ?", qualifiedColumn), filter.Value) case "gt", "greater_than": - return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s > ?", qualifiedColumn), filter.Value) case "gte", "greater_than_equals", "ge": - return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s >= ?", qualifiedColumn), filter.Value) case "lt", "less_than": - return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s < ?", qualifiedColumn), filter.Value) case "lte", "less_than_equals", "le": - return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value) case "like": - return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s LIKE ?", qualifiedColumn), filter.Value) case "ilike": // Use ILIKE for case-insensitive search (PostgreSQL) - // For other databases, cast to citext or use LOWER() - return query.Where(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), filter.Value) + // Column is already cast to TEXT if needed + return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value) case "in": - return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value) case "between": // Handle between operator - exclusive (> val1 AND < val2) if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { - return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1]) + return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1]) } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { - return query.Where(fmt.Sprintf("%s > ? AND %s < ?", filter.Column, filter.Column), values[0], values[1]) + return applyWhere(fmt.Sprintf("%s > ? AND %s < ?", qualifiedColumn, qualifiedColumn), values[0], values[1]) } logger.Warn("Invalid BETWEEN filter value format") return query case "between_inclusive": // Handle between inclusive operator - inclusive (>= val1 AND <= val2) if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 { - return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1]) + return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1]) } else if values, ok := filter.Value.([]string); ok && len(values) == 2 { - return query.Where(fmt.Sprintf("%s >= ? AND %s <= ?", filter.Column, filter.Column), values[0], values[1]) + return applyWhere(fmt.Sprintf("%s >= ? AND %s <= ?", qualifiedColumn, qualifiedColumn), values[0], values[1]) } logger.Warn("Invalid BETWEEN INCLUSIVE filter value format") return query case "is_null", "isnull": - // Check for NULL values - return query.Where(fmt.Sprintf("(%s IS NULL OR %s = '')", filter.Column, filter.Column)) + // Check for NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return applyWhere(fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName)) case "is_not_null", "isnotnull": - // Check for NOT NULL values - return query.Where(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", filter.Column, filter.Column)) + // Check for NOT NULL values - don't use cast for NULL checks + colName := h.qualifyColumnName(filter.Column, tableName) + return applyWhere(fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName)) default: logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator) - return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value) + return applyWhere(fmt.Sprintf("%s = ?", qualifiedColumn), filter.Value) } } diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 068058b..2f16588 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "reflect" "strconv" "strings" @@ -235,9 +236,10 @@ func (h *Handler) parseNotSelectFields(options *ExtendedRequestOptions, value st func (h *Handler) parseFieldFilter(options *ExtendedRequestOptions, headerKey, value string) { colName := strings.TrimPrefix(headerKey, "x-fieldfilter-") options.Filters = append(options.Filters, common.FilterOption{ - Column: colName, - Operator: "eq", - Value: value, + Column: colName, + Operator: "eq", + Value: value, + LogicOperator: "AND", // Default to AND }) } @@ -246,9 +248,10 @@ func (h *Handler) parseSearchFilter(options *ExtendedRequestOptions, headerKey, colName := strings.TrimPrefix(headerKey, "x-searchfilter-") // Use ILIKE for fuzzy search options.Filters = append(options.Filters, common.FilterOption{ - Column: colName, - Operator: "ilike", - Value: "%" + value + "%", + Column: colName, + Operator: "ilike", + Value: "%" + value + "%", + LogicOperator: "AND", // Default to AND }) } @@ -277,70 +280,68 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu colName := parts[1] // Map operator names to filter operators - filterOp := h.mapSearchOperator(operator, value) + filterOp := h.mapSearchOperator(colName, operator, value) + + // Set the logic operator (AND or OR) + filterOp.LogicOperator = logicOp options.Filters = append(options.Filters, filterOp) - // Note: OR logic would need special handling in query builder - // For now, we'll add a comment to indicate OR logic - if logicOp == "OR" { - // TODO: Implement OR logic in query builder - logger.Debug("OR logic filter: %s %s %v", colName, filterOp.Operator, filterOp.Value) - } + logger.Debug("%s logic filter: %s %s %v", logicOp, colName, filterOp.Operator, filterOp.Value) } // mapSearchOperator maps search operator names to filter operators -func (h *Handler) mapSearchOperator(operator, value string) common.FilterOption { +func (h *Handler) mapSearchOperator(colName, operator, value string) common.FilterOption { operator = strings.ToLower(operator) switch operator { - case "contains": - return common.FilterOption{Operator: "ilike", Value: "%" + value + "%"} + case "contains", "contain", "like": + return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value + "%"} case "beginswith", "startswith": - return common.FilterOption{Operator: "ilike", Value: value + "%"} + return common.FilterOption{Column: colName, Operator: "ilike", Value: value + "%"} case "endswith": - return common.FilterOption{Operator: "ilike", Value: "%" + value} - case "equals", "eq": - return common.FilterOption{Operator: "eq", Value: value} - case "notequals", "neq", "ne": - return common.FilterOption{Operator: "neq", Value: value} - case "greaterthan", "gt": - return common.FilterOption{Operator: "gt", Value: value} - case "lessthan", "lt": - return common.FilterOption{Operator: "lt", Value: value} - case "greaterthanorequal", "gte", "ge": - return common.FilterOption{Operator: "gte", Value: value} - case "lessthanorequal", "lte", "le": - return common.FilterOption{Operator: "lte", Value: value} + return common.FilterOption{Column: colName, Operator: "ilike", Value: "%" + value} + case "equals", "eq", "=": + return common.FilterOption{Column: colName, Operator: "eq", Value: value} + case "notequals", "neq", "ne", "!=", "<>": + return common.FilterOption{Column: colName, Operator: "neq", Value: value} + case "greaterthan", "gt", ">": + return common.FilterOption{Column: colName, Operator: "gt", Value: value} + case "lessthan", "lt", "<": + return common.FilterOption{Column: colName, Operator: "lt", Value: value} + case "greaterthanorequal", "gte", "ge", ">=": + return common.FilterOption{Column: colName, Operator: "gte", Value: value} + case "lessthanorequal", "lte", "le", "<=": + return common.FilterOption{Column: colName, Operator: "lte", Value: value} case "between": // Parse between values (format: "value1,value2") // Between is exclusive (> value1 AND < value2) parts := strings.Split(value, ",") if len(parts) == 2 { - return common.FilterOption{Operator: "between", Value: parts} + return common.FilterOption{Column: colName, Operator: "between", Value: parts} } - return common.FilterOption{Operator: "eq", Value: value} + return common.FilterOption{Column: colName, Operator: "eq", Value: value} case "betweeninclusive": // Parse between values (format: "value1,value2") // Between inclusive is >= value1 AND <= value2 parts := strings.Split(value, ",") if len(parts) == 2 { - return common.FilterOption{Operator: "between_inclusive", Value: parts} + return common.FilterOption{Column: colName, Operator: "between_inclusive", Value: parts} } - return common.FilterOption{Operator: "eq", Value: value} + return common.FilterOption{Column: colName, Operator: "eq", Value: value} case "in": // Parse IN values (format: "value1,value2,value3") values := strings.Split(value, ",") - return common.FilterOption{Operator: "in", Value: values} + return common.FilterOption{Column: colName, Operator: "in", Value: values} case "empty", "isnull", "null": // Check for NULL or empty string - return common.FilterOption{Operator: "is_null", Value: nil} + return common.FilterOption{Column: colName, Operator: "is_null", Value: nil} case "notempty", "isnotnull", "notnull": // Check for NOT NULL - return common.FilterOption{Operator: "is_not_null", Value: nil} + return common.FilterOption{Column: colName, Operator: "is_not_null", Value: nil} default: logger.Warn("Unknown search operator: %s, defaulting to equals", operator) - return common.FilterOption{Operator: "eq", Value: value} + return common.FilterOption{Column: colName, Operator: "eq", Value: value} } } @@ -427,10 +428,16 @@ func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) { } else if strings.HasPrefix(field, "+") { direction = "ASC" colName = strings.TrimPrefix(field, "+") + } else if strings.HasSuffix(field, " desc") { + direction = "DESC" + colName = strings.TrimSuffix(field, "desc") + } else if strings.HasSuffix(field, " asc") { + direction = "ASC" + colName = strings.TrimSuffix(field, "asc") } options.Sort = append(options.Sort, common.SortOption{ - Column: colName, + Column: strings.Trim(colName, " "), Direction: direction, }) } @@ -462,3 +469,235 @@ func (h *Handler) parseJSONHeader(value string) (map[string]interface{}, error) } return result, nil } + +// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model +func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind { + if model == nil { + return reflect.Invalid + } + + modelType := reflect.TypeOf(model) + // Dereference pointer if needed + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + // Ensure it's a struct + if modelType.Kind() != reflect.Struct { + return reflect.Invalid + } + + // Find the field by JSON tag or field name + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Check JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" { + // Parse JSON tag (format: "name,omitempty") + parts := strings.Split(jsonTag, ",") + if parts[0] == colName { + return field.Type.Kind() + } + } + + // Check field name (case-insensitive) + if strings.EqualFold(field.Name, colName) { + return field.Type.Kind() + } + + // Check snake_case conversion + snakeCaseName := toSnakeCase(field.Name) + if snakeCaseName == colName { + return field.Type.Kind() + } + } + + return reflect.Invalid +} + +// toSnakeCase converts a string from CamelCase to snake_case +func toSnakeCase(s string) string { + var result strings.Builder + for i, r := range s { + if i > 0 && r >= 'A' && r <= 'Z' { + result.WriteRune('_') + } + result.WriteRune(r) + } + return strings.ToLower(result.String()) +} + +// isNumericType checks if a reflect.Kind is a numeric type +func isNumericType(kind reflect.Kind) bool { + return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 || + kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint || + kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 || + kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64 +} + +// isStringType checks if a reflect.Kind is a string type +func isStringType(kind reflect.Kind) bool { + return kind == reflect.String +} + +// isBoolType checks if a reflect.Kind is a boolean type +func isBoolType(kind reflect.Kind) bool { + return kind == reflect.Bool +} + +// convertToNumericType converts a string value to the appropriate numeric type +func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) { + value = strings.TrimSpace(value) + + switch kind { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + // Parse as integer + bitSize := 64 + switch kind { + case reflect.Int8: + bitSize = 8 + case reflect.Int16: + bitSize = 16 + case reflect.Int32: + bitSize = 32 + } + + intVal, err := strconv.ParseInt(value, 10, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid integer value: %w", err) + } + + // Return the appropriate type + switch kind { + case reflect.Int: + return int(intVal), nil + case reflect.Int8: + return int8(intVal), nil + case reflect.Int16: + return int16(intVal), nil + case reflect.Int32: + return int32(intVal), nil + case reflect.Int64: + return intVal, nil + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + // Parse as unsigned integer + bitSize := 64 + switch kind { + case reflect.Uint8: + bitSize = 8 + case reflect.Uint16: + bitSize = 16 + case reflect.Uint32: + bitSize = 32 + } + + uintVal, err := strconv.ParseUint(value, 10, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid unsigned integer value: %w", err) + } + + // Return the appropriate type + switch kind { + case reflect.Uint: + return uint(uintVal), nil + case reflect.Uint8: + return uint8(uintVal), nil + case reflect.Uint16: + return uint16(uintVal), nil + case reflect.Uint32: + return uint32(uintVal), nil + case reflect.Uint64: + return uintVal, nil + } + + case reflect.Float32, reflect.Float64: + // Parse as float + bitSize := 64 + if kind == reflect.Float32 { + bitSize = 32 + } + + floatVal, err := strconv.ParseFloat(value, bitSize) + if err != nil { + return nil, fmt.Errorf("invalid float value: %w", err) + } + + if kind == reflect.Float32 { + return float32(floatVal), nil + } + return floatVal, nil + } + + return nil, fmt.Errorf("unsupported numeric type: %v", kind) +} + +// isNumericValue checks if a string value can be parsed as a number +func isNumericValue(value string) bool { + value = strings.TrimSpace(value) + _, err := strconv.ParseFloat(value, 64) + return err == nil +} + +// ColumnCastInfo holds information about whether a column needs casting +type ColumnCastInfo struct { + NeedsCast bool + IsNumericType bool +} + +// ValidateAndAdjustFilterForColumnType validates and adjusts a filter based on column type +// Returns ColumnCastInfo indicating whether the column should be cast to text in SQL +func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOption, model interface{}) ColumnCastInfo { + if filter == nil || model == nil { + return ColumnCastInfo{NeedsCast: false, IsNumericType: false} + } + + colType := h.getColumnTypeFromModel(model, filter.Column) + if colType == reflect.Invalid { + // Column not found in model, no casting needed + logger.Debug("Column %s not found in model, skipping type validation", filter.Column) + return ColumnCastInfo{NeedsCast: false, IsNumericType: false} + } + + // Check if the input value is numeric + valueIsNumeric := false + if strVal, ok := filter.Value.(string); ok { + strVal = strings.Trim(strVal, "%") + valueIsNumeric = isNumericValue(strVal) + } + + // Adjust based on column type + switch { + case isNumericType(colType): + // Column is numeric + if valueIsNumeric { + // Value is numeric - try to convert it + if strVal, ok := filter.Value.(string); ok { + strVal = strings.Trim(strVal, "%") + numericVal, err := convertToNumericType(strVal, colType) + if err != nil { + logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column) + return ColumnCastInfo{NeedsCast: true, IsNumericType: true} + } + filter.Value = numericVal + } + // No cast needed - numeric column with numeric value + return ColumnCastInfo{NeedsCast: false, IsNumericType: true} + } else { + // Value is not numeric - cast column to text for comparison + logger.Debug("Non-numeric value for numeric column %s, will cast to text", filter.Column) + return ColumnCastInfo{NeedsCast: true, IsNumericType: true} + } + + case isStringType(colType): + // String columns don't need casting + return ColumnCastInfo{NeedsCast: false, IsNumericType: false} + + default: + // For bool, time.Time, and other complex types - cast to text + logger.Debug("Complex type column %s, will cast to text", filter.Column) + return ColumnCastInfo{NeedsCast: true, IsNumericType: false} + } +}