mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-02 01:44:25 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
68dee78a34 | ||
|
|
efb9e5d9d5 | ||
|
|
490ae37c6d | ||
|
|
99307e31e6 |
@@ -48,21 +48,42 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Log the type being scanned into
|
// Log the type being scanned into
|
||||||
logger.Debug("Debug scan into type: %s (kind: %s)", v.Type().Name(), v.Kind())
|
typeName := v.Type().String()
|
||||||
|
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
|
||||||
|
|
||||||
// If it's a struct, log all field types
|
// Handle slice types - inspect the element type
|
||||||
if v.Kind() == reflect.Struct {
|
var structType reflect.Type
|
||||||
for i := 0; i < v.NumField(); i++ {
|
if v.Kind() == reflect.Slice {
|
||||||
field := v.Type().Field(i)
|
elemType := v.Type().Elem()
|
||||||
fieldValue := v.Field(i)
|
logger.Debug(" Slice element type: %s", elemType)
|
||||||
|
|
||||||
|
// If slice of pointers, get the underlying type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
structType = elemType.Elem()
|
||||||
|
} else {
|
||||||
|
structType = elemType
|
||||||
|
}
|
||||||
|
} else if v.Kind() == reflect.Struct {
|
||||||
|
structType = v.Type()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a struct type, log all its fields
|
||||||
|
if structType != nil && structType.Kind() == reflect.Struct {
|
||||||
|
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
|
||||||
|
for i := 0; i < structType.NumField(); i++ {
|
||||||
|
field := structType.Field(i)
|
||||||
|
|
||||||
// Log embedded fields specially
|
// Log embedded fields specially
|
||||||
if field.Anonymous {
|
if field.Anonymous {
|
||||||
logger.Debug(" Embedded field [%d]: %s (type: %s, kind: %s)",
|
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
|
||||||
i, field.Name, field.Type, fieldValue.Kind())
|
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
|
||||||
} else {
|
} else {
|
||||||
logger.Debug(" Field [%d]: %s (type: %s, kind: %s, tag: %s)",
|
bunTag := field.Tag.Get("bun")
|
||||||
i, field.Name, field.Type, fieldValue.Kind(), field.Tag.Get("bun"))
|
if bunTag == "" {
|
||||||
|
bunTag = "(no tag)"
|
||||||
|
}
|
||||||
|
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
|
||||||
|
i, field.Name, field.Type, field.Type.Kind(), bunTag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -430,7 +430,45 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
// Remove any quotes
|
// Remove any quotes
|
||||||
columnRef = strings.Trim(columnRef, "`\"'")
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
// Check if it contains a dot (qualified reference)
|
// Check if there's a function call (contains opening parenthesis)
|
||||||
|
openParenIdx := strings.Index(columnRef, "(")
|
||||||
|
|
||||||
|
if openParenIdx >= 0 {
|
||||||
|
// There's a function call - find the FIRST dot after the opening paren
|
||||||
|
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
|
||||||
|
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
|
||||||
|
if dotIdx > 0 {
|
||||||
|
dotIdx += openParenIdx // Adjust to absolute position
|
||||||
|
|
||||||
|
// Extract table name (between paren and dot)
|
||||||
|
// Find the last opening paren before this dot
|
||||||
|
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
|
||||||
|
table = columnRef[lastOpenParen+1 : dotIdx]
|
||||||
|
|
||||||
|
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
|
||||||
|
columnStart := dotIdx + 1
|
||||||
|
columnEnd := len(columnRef)
|
||||||
|
|
||||||
|
for i := columnStart; i < len(columnRef); i++ {
|
||||||
|
ch := columnRef[i]
|
||||||
|
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
|
||||||
|
columnEnd = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
column = columnRef[columnStart:columnEnd]
|
||||||
|
|
||||||
|
// Remove quotes from table and column if present
|
||||||
|
table = strings.Trim(table, "`\"'")
|
||||||
|
column = strings.Trim(column, "`\"'")
|
||||||
|
|
||||||
|
return table, column
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No function call - check if it contains a dot (qualified reference)
|
||||||
|
// Use LastIndex to handle schema.table.column properly
|
||||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||||
table = columnRef[:dotIdx]
|
table = columnRef[:dotIdx]
|
||||||
column = columnRef[dotIdx+1:]
|
column = columnRef[dotIdx+1:]
|
||||||
|
|||||||
@@ -286,6 +286,48 @@ func TestExtractTableAndColumn(t *testing.T) {
|
|||||||
expectedTable: "",
|
expectedTable: "",
|
||||||
expectedCol: "",
|
expectedCol: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "function call with table.column - ifblnk",
|
||||||
|
input: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function call with table.column - coalesce",
|
||||||
|
input: "coalesce(users.age, 0) = 25",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "age",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested function calls",
|
||||||
|
input: "upper(trim(users.name)) = 'JOHN'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function with multiple args and table.column",
|
||||||
|
input: "substring(users.email, 1, 5) = 'admin'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cast function with table.column",
|
||||||
|
input: "cast(orders.total as decimal) > 100",
|
||||||
|
expectedTable: "orders",
|
||||||
|
expectedCol: "total",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex nested functions",
|
||||||
|
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function with multiple table.column refs (extracts first)",
|
||||||
|
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "created_at",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -352,6 +394,14 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "Function Call with correct table prefix - unchanged",
|
||||||
|
where: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
tableName: "users",
|
||||||
|
options: nil,
|
||||||
|
expected: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no options provided - works as before",
|
name: "no options provided - works as before",
|
||||||
where: "wrong_table.status = 'active'",
|
where: "wrong_table.status = 'active'",
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
|
|
||||||
// Validate and filter columns in options (log warnings for invalid columns)
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
options = filterExtendedOptions(validator, options)
|
options = h.filterExtendedOptions(validator, options, model)
|
||||||
|
|
||||||
// Add request-scoped data to context (including options)
|
// Add request-scoped data to context (including options)
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
@@ -2241,7 +2241,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||||
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions {
|
||||||
filtered := options
|
filtered := options
|
||||||
|
|
||||||
// Filter base RequestOptions
|
// Filter base RequestOptions
|
||||||
@@ -2265,12 +2265,30 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
|
|||||||
// No filtering needed for ComputedQL keys
|
// No filtering needed for ComputedQL keys
|
||||||
filtered.ComputedQL = options.ComputedQL
|
filtered.ComputedQL = options.ComputedQL
|
||||||
|
|
||||||
// Filter Expand columns
|
// Filter Expand columns using the expand relation's model
|
||||||
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
for _, expand := range options.Expand {
|
for _, expand := range options.Expand {
|
||||||
filteredExpand := expand
|
filteredExpand := expand
|
||||||
// Don't validate relation name, only columns
|
|
||||||
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
|
// Get the relationship info for this expand relation
|
||||||
|
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
|
||||||
|
if relInfo != nil && relInfo.relatedModel != nil {
|
||||||
|
// Create a validator for the related model
|
||||||
|
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||||
|
// Filter columns using the related model's validator
|
||||||
|
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||||
|
} else {
|
||||||
|
// If we can't find the relationship, log a warning and skip column filtering
|
||||||
|
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||||
|
// Keep the columns as-is if we can't validate them
|
||||||
|
filteredExpand.Columns = expand.Columns
|
||||||
|
}
|
||||||
|
|
||||||
filteredExpands = append(filteredExpands, filteredExpand)
|
filteredExpands = append(filteredExpands, filteredExpand)
|
||||||
}
|
}
|
||||||
filtered.Expand = filteredExpands
|
filtered.Expand = filteredExpands
|
||||||
|
|||||||
Reference in New Issue
Block a user