Lots of refactoring, Fixes to preloads

This commit is contained in:
Hein
2025-11-21 10:17:20 +02:00
parent 7853a3f56a
commit 7e76977dcc
12 changed files with 563 additions and 386 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ExtendedRequestOptions extends common.RequestOptions with additional features
@@ -122,78 +123,77 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
// Merge headers and query parameters - query parameters take precedence
// This allows the same parameters to be specified in either headers or query string
// Normalize keys to lowercase to ensure query params properly override headers
combinedParams := make(map[string]string)
for key, value := range headers {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
for key, value := range queryParams {
combinedParams[key] = value
combinedParams[strings.ToLower(key)] = value
}
// Process each parameter (from both headers and query params)
// Note: keys are already normalized to lowercase in combinedParams
for key, value := range combinedParams {
// Normalize parameter key to lowercase for consistent matching
normalizedKey := strings.ToLower(key)
// Decode value if it's base64 encoded
decodedValue := decodeHeaderValue(value)
// Parse based on parameter prefix/name
switch {
// Field Selection
case strings.HasPrefix(normalizedKey, "x-select-fields"):
case strings.HasPrefix(key, "x-select-fields"):
h.parseSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
case strings.HasPrefix(key, "x-not-select-fields"):
h.parseNotSelectFields(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-clean-json"):
case strings.HasPrefix(key, "x-clean-json"):
options.CleanJSON = strings.EqualFold(decodedValue, "true")
// Filtering & Search
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
h.parseFieldFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
h.parseSearchFilter(&options, normalizedKey, decodedValue)
case strings.HasPrefix(normalizedKey, "x-searchop-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchor-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
case strings.HasPrefix(normalizedKey, "x-searchand-"):
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
case strings.HasPrefix(normalizedKey, "x-searchcols"):
case strings.HasPrefix(key, "x-fieldfilter-"):
h.parseFieldFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchfilter-"):
h.parseSearchFilter(&options, key, decodedValue)
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(&options, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(&options, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchcols"):
options.SearchColumns = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
case strings.HasPrefix(key, "x-custom-sql-w"):
options.CustomSQLWhere = decodedValue
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
case strings.HasPrefix(key, "x-custom-sql-or"):
options.CustomSQLOr = decodedValue
// Joins & Relations
case strings.HasPrefix(normalizedKey, "x-preload"):
if strings.HasSuffix(normalizedKey, "-where") {
case strings.HasPrefix(key, "x-preload"):
if strings.HasSuffix(key, "-where") {
continue
}
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
case strings.HasPrefix(normalizedKey, "x-expand"):
case strings.HasPrefix(key, "x-expand"):
h.parseExpand(&options, decodedValue)
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
case strings.HasPrefix(key, "x-custom-sql-join"):
// TODO: Implement custom SQL join
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
// Sorting & Pagination
case strings.HasPrefix(normalizedKey, "x-sort"):
case strings.HasPrefix(key, "x-sort"):
h.parseSorting(&options, decodedValue)
//Special cases for older clients using sort(a,b,-c) syntax
case strings.HasPrefix(normalizedKey, "sort(") && strings.Contains(normalizedKey, ")"):
sortValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
// Special cases for older clients using sort(a,b,-c) syntax
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
h.parseSorting(&options, sortValue)
case strings.HasPrefix(normalizedKey, "x-limit"):
case strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil {
options.Limit = &limit
}
//Special cases for older clients using limit(n) syntax
case strings.HasPrefix(normalizedKey, "limit(") && strings.Contains(normalizedKey, ")"):
limitValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
// Special cases for older clients using limit(n) syntax
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
limitValueParts := strings.Split(limitValue, ",")
if len(limitValueParts) > 1 {
@@ -209,42 +209,42 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
}
case strings.HasPrefix(normalizedKey, "x-offset"):
case strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil {
options.Offset = &offset
}
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
case strings.HasPrefix(key, "x-cursor-forward"):
options.CursorForward = decodedValue
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
case strings.HasPrefix(key, "x-cursor-backward"):
options.CursorBackward = decodedValue
// Advanced Features
case strings.HasPrefix(normalizedKey, "x-advsql-"):
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
case strings.HasPrefix(key, "x-advsql-"):
colName := strings.TrimPrefix(key, "x-advsql-")
options.AdvancedSQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
case strings.HasPrefix(key, "x-cql-sel-"):
colName := strings.TrimPrefix(key, "x-cql-sel-")
options.ComputedQL[colName] = decodedValue
case strings.HasPrefix(normalizedKey, "x-distinct"):
case strings.HasPrefix(key, "x-distinct"):
options.Distinct = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcount"):
case strings.HasPrefix(key, "x-skipcount"):
options.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-skipcache"):
case strings.HasPrefix(key, "x-skipcache"):
options.SkipCache = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
case strings.HasPrefix(key, "x-fetch-rownumber"):
options.FetchRowNumber = &decodedValue
case strings.HasPrefix(normalizedKey, "x-pkrow"):
case strings.HasPrefix(key, "x-pkrow"):
options.PKRow = &decodedValue
// Response Format
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
case strings.HasPrefix(key, "x-simpleapi"):
options.ResponseFormat = "simple"
case strings.HasPrefix(normalizedKey, "x-detailapi"):
case strings.HasPrefix(key, "x-detailapi"):
options.ResponseFormat = "detail"
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
case strings.HasPrefix(key, "x-syncfusion"):
options.ResponseFormat = "syncfusion"
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
case strings.HasPrefix(key, "x-single-record-as-object"):
// Parse as boolean - "false" disables, "true" enables (default is true)
if strings.EqualFold(decodedValue, "false") {
options.SingleRecordAsObject = false
@@ -253,11 +253,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
// Transaction Control
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
case strings.HasPrefix(key, "x-transaction-atomic"):
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
// X-Files - comprehensive JSON configuration
case strings.HasPrefix(normalizedKey, "x-files"):
case strings.HasPrefix(key, "x-files"):
h.parseXFiles(&options, decodedValue)
}
}
@@ -720,7 +720,7 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Try to get the model type for the next level
// This allows nested resolution
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
currentModel = nextModel
}
}
@@ -744,58 +744,6 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
}
}
// getRelationModel gets the model type for a relation field
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}
// resolveRelationName resolves a relation name or table name to the actual field name in the model
// If the input is already a field name, it returns it as-is
// If the input is a table name, it looks up the corresponding relation field
@@ -983,192 +931,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := extractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// toSnakeCase converts a string from CamelCase to snake_case
func toSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// isNumericType checks if a reflect.Kind is a numeric type
func isNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// isStringType checks if a reflect.Kind is a string type
func isStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// convertToNumericType converts a string value to the appropriate numeric type
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// isNumericValue checks if a string value can be parsed as a number
func isNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ColumnCastInfo holds information about whether a column needs casting
type ColumnCastInfo struct {
NeedsCast bool
@@ -1182,7 +944,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
}
colType := h.getColumnTypeFromModel(model, filter.Column)
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
if colType == reflect.Invalid {
// Column not found in model, no casting needed
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
@@ -1193,18 +955,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
valueIsNumeric := false
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
valueIsNumeric = isNumericValue(strVal)
valueIsNumeric = reflection.IsNumericValue(strVal)
}
// Adjust based on column type
switch {
case isNumericType(colType):
case reflection.IsNumericType(colType):
// Column is numeric
if valueIsNumeric {
// Value is numeric - try to convert it
if strVal, ok := filter.Value.(string); ok {
strVal = strings.Trim(strVal, "%")
numericVal, err := convertToNumericType(strVal, colType)
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
if err != nil {
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
@@ -1219,7 +981,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
}
case isStringType(colType):
case reflection.IsStringType(colType):
// String columns don't need casting
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}