Lots of refactoring, Fixes to preloads

This commit is contained in:
Hein 2025-11-21 10:17:20 +02:00
parent 7853a3f56a
commit 7e76977dcc
12 changed files with 563 additions and 386 deletions

View File

@ -86,7 +86,6 @@
"emptyFallthrough",
"equalFold",
"flagName",
"ifElseChain",
"indexAlloc",
"initClause",
"methodExprCall",
@ -106,6 +105,9 @@
"unnecessaryBlock",
"weakCond",
"yodaStyleExpr"
],
"disabled-checks": [
"ifElseChain"
]
},
"revive": {

View File

@ -237,7 +237,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
if err != nil {
return
}
}
}()
if len(apply) == 0 {
@ -401,7 +404,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly

View File

@ -96,6 +96,117 @@ func IsSQLExpression(cond string) bool {
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
// - An empty string if all conditions were trivial or the input was empty
func SanitizeWhereClause(where string, tableName string) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(cond) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition doesn't already have a table prefix,
// attempt to add it
if tableName != "" && !hasTablePrefix(cond) {
// Check if this is a SQL expression/literal that shouldn't be prefixed
if !IsSQLExpression(strings.ToLower(cond)) {
// Extract the column name and prefix it
columnName := ExtractColumnName(cond)
if columnName != "" {
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond)
}
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions
func splitByAND(where string) []string {
// First try uppercase AND
conditions := strings.Split(where, " AND ")
// If we didn't split on uppercase, try lowercase
if len(conditions) == 1 {
conditions = strings.Split(where, " and ")
}
// If we still didn't split, try mixed case
if len(conditions) == 1 {
conditions = strings.Split(where, " And ")
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {

View File

@ -238,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
var err error
if b == nil {
t = &SqlTimeStamp{}
return nil
}
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
t = &SqlTimeStamp{}
return nil
}
@ -293,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error {
// String - Override String format of time
func (t SqlTimeStamp) String() string {
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
return time.Time(t).Format("2006-01-02T15:04:05")
}
// GetTime - Returns Time
@ -308,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) {
// Format - Formats the time
func (t SqlTimeStamp) Format(layout string) string {
return fmt.Sprintf("%s", time.Time(t).Format(layout))
return time.Time(t).Format(layout)
}
func SqlTimeStampNow() SqlTimeStamp {
@ -420,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
if s == "null" || s == "" || s == "0" ||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
s == "0001-01-01" {
t = &SqlDate{}
return nil
}
@ -517,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error {
*t = SqlTime{}
return nil
}
tx := time.Time{}
tx, err = tryParseDT(s)
tx, err := tryParseDT(s)
*t = SqlTime(tx)
return err
@ -642,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) {
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
if invalid {
s = ""
return nil
}
@ -725,7 +723,6 @@ func (n *SqlUUID) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 30)
if invalid {
s = ""
return nil
}
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})

View File

@ -6,6 +6,7 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ColumnValidator validates column names against a model's fields
@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {

View File

@ -2,6 +2,8 @@ package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestExtractSourceColumn(t *testing.T) {
@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}

View File

@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
}
}()
var lst []ModelFieldDetail
lst = make([]ModelFieldDetail, 0)
lst := make([]ModelFieldDetail, 0)
if !record.IsValid() {
return lst

View File

@ -1,7 +1,9 @@
package reflection
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any {
}
// Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
return fieldValue.Interface()
}
}
@ -472,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool {
// isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
@ -561,3 +563,290 @@ func isGormFieldReadOnly(tag string) bool {
}
return false
}
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func ExtractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ToSnakeCase converts a string from CamelCase to snake_case
func ToSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := ExtractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := ToSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// IsNumericType checks if a reflect.Kind is a numeric type
func IsNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// IsStringType checks if a reflect.Kind is a string type
func IsStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// IsNumericValue checks if a string value can be parsed as a number
func IsNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ConvertToNumericType converts a string value to the appropriate numeric type
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// GetRelationModel gets the model type for a relation field
// It searches for the field by name in the following order (case-insensitive):
// 1. Actual field name
// 2. Bun tag name (if exists)
// 3. Gorm tag name (if exists)
// 4. JSON tag name (if exists)
func GetRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field by checking in priority order (case-insensitive)
var field *reflect.StructField
normalizedFieldName := strings.ToLower(fieldName)
for i := 0; i < modelType.NumField(); i++ {
f := modelType.Field(i)
// 1. Check actual field name (case-insensitive)
if strings.EqualFold(f.Name, fieldName) {
field = &f
break
}
// 2. Check bun tag name
bunTag := f.Tag.Get("bun")
if bunTag != "" {
bunColName := ExtractColumnFromBunTag(bunTag)
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
field = &f
break
}
}
// 3. Check gorm tag name
gormTag := f.Tag.Get("gorm")
if gormTag != "" {
gormColName := ExtractColumnFromGormTag(gormTag)
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
field = &f
break
}
}
// 4. Check JSON tag name
jsonTag := f.Tag.Get("json")
if jsonTag != "" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
if strings.EqualFold(parts[0], normalizedFieldName) {
field = &f
break
}
}
}
}
if field == nil {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}

View File

@ -1149,6 +1149,11 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
preload.Columns = reflection.GetSQLModelColumns(model)
}
// Handle column selection and omission
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetSQLModelColumns(model)
// Remove omitted columns
@ -1204,7 +1209,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
}
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
}
if preload.Limit != nil && *preload.Limit > 0 {

View File

@ -391,13 +391,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
query = query.Where(options.CustomSQLWhere)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "")
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
}
// Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
query = query.WhereOr(options.CustomSQLOr)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "")
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
}
// If ID is provided, filter by ID
@ -473,7 +481,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
query = query.Where(cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "")
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
}
}
@ -552,15 +563,16 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply the preload
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
// Get the related model for column operations
relatedModel := h.getRelationModel(model, preload.Relation)
relationParts := strings.Split(preload.Relation, ",")
relatedModel := reflection.GetRelationModel(model, relationParts[0])
if relatedModel == nil {
logger.Warn("Could not get related model for preload: %s", preload.Relation)
relatedModel = model // fallback to parent model
}
// relatedModel = model // fallback to parent model
} else {
// If we have computed columns but no explicit columns, populate with all model columns first
// since computed columns are additions
if len(preload.Columns) == 0 && len(preload.ComputedQL) > 0 && relatedModel != nil {
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
}
@ -581,8 +593,8 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
}
// Handle OmitColumns
if len(preload.OmitColumns) > 0 && relatedModel != nil {
allCols := reflection.GetModelColumns(relatedModel)
if len(preload.OmitColumns) > 0 {
allCols := preload.Columns
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
@ -603,6 +615,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
}
}
// Apply filters
if len(preload.Filters) > 0 {
@ -620,7 +633,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply WHERE clause
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
}
// Apply limit
@ -628,6 +644,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
sq = sq.Limit(*preload.Limit)
}
if preload.Offset != nil && *preload.Offset > 0 {
sq = sq.Offset(*preload.Offset)
}
return sq
})
@ -1312,7 +1332,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
func (h *Handler) extractNestedRelations(
data map[string]interface{},
model interface{},
) (map[string]interface{}, map[string]interface{}, error) {
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
// Get model type for reflection
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
@ -1741,7 +1761,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil {
return data
return nil
}
// Use reflection to check if data is a slice or array

View File

@ -10,6 +10,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ExtendedRequestOptions extends common.RequestOptions with additional features
@ -122,78 +123,77 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
// Merge headers and query parameters - query parameters take precedence
// This allows the same parameters to be specified in either headers or query string
// Normalize keys to lowercase to ensure query params properly override headers
combinedParams := make(map[string]string)
for key, value := range headers {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
for key, value := range queryParams {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
// Process each parameter (from both headers and query params)
// Note: keys are already normalized to lowercase in combinedParams
for key, value := range combinedParams {
// Normalize parameter key to lowercase for consistent matching
normalizedKey := strings.ToLower(key)
// Decode value if it's base64 encoded
decodedValue := decodeHeaderValue(value)
// Parse based on parameter prefix/name
switch {
// Field Selection
case strings.HasPrefix(normalizedKey, "x-select-fields"):
case strings.HasPrefix(key, "x-select-fields"):
h.parseSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
case strings.HasPrefix(key, "x-not-select-fields"):
h.parseNotSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-clean-json"):
case strings.HasPrefix(key, "x-clean-json"):
options.CleanJSON = strings.EqualFold(decodedValue, "true")
// Filtering & Search
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
h.parseFieldFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
h.parseSearchFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchop-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchor-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
case strings.HasPrefix(normalizedKey, "x-searchand-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchcols"):
case strings.HasPrefix(key, "x-fieldfilter-"):
h.parseFieldFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchfilter-"):
h.parseSearchFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(&options, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchcols"):
options.SearchColumns = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
case strings.HasPrefix(key, "x-custom-sql-w"):
options.CustomSQLWhere = decodedValue
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
case strings.HasPrefix(key, "x-custom-sql-or"):
options.CustomSQLOr = decodedValue
// Joins & Relations
case strings.HasPrefix(normalizedKey, "x-preload"):
if strings.HasSuffix(normalizedKey, "-where") {
case strings.HasPrefix(key, "x-preload"):
if strings.HasSuffix(key, "-where") {
continue
}
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
case strings.HasPrefix(normalizedKey, "x-expand"):
case strings.HasPrefix(key, "x-expand"):
h.parseExpand(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
case strings.HasPrefix(key, "x-custom-sql-join"):
// TODO: Implement custom SQL join
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
// Sorting & Pagination
case strings.HasPrefix(normalizedKey, "x-sort"):
case strings.HasPrefix(key, "x-sort"):
h.parseSorting(&options, decodedValue)
// Special cases for older clients using sort(a,b,-c) syntax
case strings.HasPrefix(normalizedKey, "sort(") && strings.Contains(normalizedKey, ")"):
sortValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
h.parseSorting(&options, sortValue)
case strings.HasPrefix(normalizedKey, "x-limit"):
case strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil {
options.Limit = &limit
}
// Special cases for older clients using limit(n) syntax
case strings.HasPrefix(normalizedKey, "limit(") && strings.Contains(normalizedKey, ")"):
limitValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
limitValueParts := strings.Split(limitValue, ",")
if len(limitValueParts) > 1 {
@ -209,42 +209,42 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
}
case strings.HasPrefix(normalizedKey, "x-offset"):
case strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil {
options.Offset = &offset
}
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
case strings.HasPrefix(key, "x-cursor-forward"):
options.CursorForward = decodedValue
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
case strings.HasPrefix(key, "x-cursor-backward"):
options.CursorBackward = decodedValue
// Advanced Features
case strings.HasPrefix(normalizedKey, "x-advsql-"):
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
case strings.HasPrefix(key, "x-advsql-"):
colName := strings.TrimPrefix(key, "x-advsql-")
options.AdvancedSQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
case strings.HasPrefix(key, "x-cql-sel-"):
colName := strings.TrimPrefix(key, "x-cql-sel-")
options.ComputedQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-distinct"):
case strings.HasPrefix(key, "x-distinct"):
options.Distinct = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcount"):
case strings.HasPrefix(key, "x-skipcount"):
options.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcache"):
case strings.HasPrefix(key, "x-skipcache"):
options.SkipCache = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
case strings.HasPrefix(key, "x-fetch-rownumber"):
options.FetchRowNumber = &decodedValue
case strings.HasPrefix(normalizedKey, "x-pkrow"):
case strings.HasPrefix(key, "x-pkrow"):
options.PKRow = &decodedValue
// Response Format
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
case strings.HasPrefix(key, "x-simpleapi"):
options.ResponseFormat = "simple"
case strings.HasPrefix(normalizedKey, "x-detailapi"):
case strings.HasPrefix(key, "x-detailapi"):
options.ResponseFormat = "detail"
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
case strings.HasPrefix(key, "x-syncfusion"):
options.ResponseFormat = "syncfusion"
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
case strings.HasPrefix(key, "x-single-record-as-object"):
// Parse as boolean - "false" disables, "true" enables (default is true)
if strings.EqualFold(decodedValue, "false") {
options.SingleRecordAsObject = false
@ -253,11 +253,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
// Transaction Control
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
case strings.HasPrefix(key, "x-transaction-atomic"):
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
// X-Files - comprehensive JSON configuration
case strings.HasPrefix(normalizedKey, "x-files"):
case strings.HasPrefix(key, "x-files"):
h.parseXFiles(&options, decodedValue)
}
}
@ -720,7 +720,7 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Try to get the model type for the next level
// This allows nested resolution
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
currentModel = nextModel
}
}
@ -744,58 +744,6 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
}
}
// getRelationModel gets the model type for a relation field
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}
// resolveRelationName resolves a relation name or table name to the actual field name in the model
// If the input is already a field name, it returns it as-is
// If the input is a table name, it looks up the corresponding relation field
@ -983,192 +931,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := extractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// toSnakeCase converts a string from CamelCase to snake_case
func toSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// isNumericType checks if a reflect.Kind is a numeric type
func isNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// isStringType checks if a reflect.Kind is a string type
func isStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// convertToNumericType converts a string value to the appropriate numeric type
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// isNumericValue checks if a string value can be parsed as a number
func isNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ColumnCastInfo holds information about whether a column needs casting
type ColumnCastInfo struct {
NeedsCast bool
@ -1182,7 +944,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
}
colType := h.getColumnTypeFromModel(model, filter.Column)
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
if colType == reflect.Invalid {
// Column not found in model, no casting needed
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
@ -1193,18 +955,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
valueIsNumeric := false
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
valueIsNumeric = isNumericValue(strVal)
valueIsNumeric = reflection.IsNumericValue(strVal)
}
// Adjust based on column type
switch {
case isNumericType(colType):
case reflection.IsNumericType(colType):
// Column is numeric
if valueIsNumeric {
// Value is numeric - try to convert it
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
numericVal, err := convertToNumericType(strVal, colType)
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
if err != nil {
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
@ -1219,7 +981,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
}
case isStringType(colType):
case reflection.IsStringType(colType):
// String columns don't need casting
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}