Compare commits

...

4 Commits

Author SHA1 Message Date
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
Hein
a931b8cdd2 Better preloads 2025-11-21 10:41:58 +02:00
8 changed files with 449 additions and 15 deletions

View File

@@ -5,6 +5,8 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
@@ -135,6 +137,15 @@ func SanitizeWhereClause(where string, tableName string) string {
where = strings.TrimSpace(where)
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
@@ -146,22 +157,32 @@ func SanitizeWhereClause(where string, tableName string) string {
continue
}
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(cond) {
if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition doesn't already have a table prefix,
// attempt to add it
if tableName != "" && !hasTablePrefix(cond) {
if tableName != "" && !hasTablePrefix(condToCheck) {
// Check if this is a SQL expression/literal that shouldn't be prefixed
if !IsSQLExpression(strings.ToLower(cond)) {
if !IsSQLExpression(strings.ToLower(condToCheck)) {
// Extract the column name and prefix it
columnName := ExtractColumnName(cond)
columnName := ExtractColumnName(condToCheck)
if columnName != "" {
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond)
// Only prefix if this is a valid column in the model
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace in the original condition (without stripped parens)
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond)
} else {
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
}
}
}
}
@@ -182,6 +203,43 @@ func SanitizeWhereClause(where string, tableName string) string {
return result
}
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions
func splitByAND(where string) []string {
@@ -245,3 +303,38 @@ func IsSQLKeyword(word string) bool {
}
return false
}
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@@ -0,0 +1,224 @@
package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses",
where: "(status = 'active')",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition already with table prefix",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple valid conditions",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column gets prefixed",
where: "status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple valid columns get prefixed",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
{
name: "invalid column does not get prefixed",
where: "invalid_column = 'value'",
tableName: "mastertask",
expected: "invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "parentheses with valid column",
where: "(status = 'active')",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

View File

@@ -43,6 +43,11 @@ type PreloadOption struct {
Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
}
type FilterOption struct {

View File

@@ -17,3 +17,33 @@ func Len(v any) int {
return 0
}
}
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
// dots, it returns everything after the last dot up to the first delimiter.
func ExtractTableNameOnly(fullName string) string {
// First, split by dot to remove schema prefix if present
lastDotIndex := -1
for i, char := range fullName {
if char == '.' {
lastDotIndex = i
}
}
// Start from after the last dot (or from beginning if no dot)
startIndex := 0
if lastDotIndex != -1 {
startIndex = lastDotIndex + 1
}
// Now find the end (first delimiter after the table name)
for i := startIndex; i < len(fullName); i++ {
char := rune(fullName[i])
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
return fullName[startIndex:i]
}
}
return fullName[startIndex:]
}

View File

@@ -756,11 +756,42 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
// 2. Bun tag name (if exists)
// 3. Gorm tag name (if exists)
// 4. JSON tag name (if exists)
//
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
// For nested fields, it traverses through each level of the struct hierarchy
func GetRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
// Split the field name by "." to handle nested/recursive relations
fieldParts := strings.Split(fieldName, ".")
// Start with the current model
currentModel := model
// Traverse through each level of the field path
for _, part := range fieldParts {
if part == "" {
continue
}
currentModel = getRelationModelSingleLevel(currentModel, part)
if currentModel == nil {
return nil
}
}
return currentModel
}
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
// This is a helper function used by GetRelationModel to handle one level at a time
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil

View File

@@ -199,7 +199,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection
if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...)
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
}
if len(options.ComputedColumns) > 0 {
@@ -1209,7 +1211,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
}
if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}

View File

@@ -213,6 +213,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
tableName := GetTableName(ctx)
model := GetModel(ctx)
if id == "" {
options.SingleRecordAsObject = false
}
// Execute BeforeRead hooks
hookCtx := &HookContext{
Context: ctx,
@@ -299,7 +303,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection
if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...)
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
}
// Apply expand (Just expand to Preload for now)
@@ -392,7 +399,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "")
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
@@ -402,7 +409,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "")
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
@@ -481,7 +488,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "")
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
@@ -560,11 +567,33 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
// Log relationship keys if they're specified (from XFiles)
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
// Build a WHERE clause using the relationship keys if needed
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
// However, if the relationship keys are explicitly provided from XFiles, we can use them
// to add additional filtering or validation
if preload.RelatedKey != "" && preload.Where == "" {
// For child tables: ensure the child's relatedkey column will be matched
// The actual parent value is dynamic and handled by Bun's preload mechanism
// We just log this for visibility
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
preload.Relation, preload.RelatedKey)
}
if preload.ForeignKey != "" && preload.Where == "" {
// For parent tables: ensure the parent's primary key matches the current table's foreign key
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
preload.Relation, preload.ForeignKey)
}
}
// Apply the preload
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
// Get the related model for column operations
relationParts := strings.Split(preload.Relation, ",")
relatedModel := reflection.GetRelationModel(model, relationParts[0])
relatedModel := reflection.GetRelationModel(model, preload.Relation)
if relatedModel == nil {
logger.Warn("Could not get related model for preload: %s", preload.Relation)
// relatedModel = model // fallback to parent model
@@ -633,7 +662,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
// Apply WHERE clause
if len(preload.Where) > 0 {
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}

View File

@@ -267,6 +267,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
h.resolveRelationNamesInOptions(&options, model)
}
// Always sort according to the primary key if no sorting is specified
if len(options.Sort) == 0 {
pkName := reflection.GetPrimaryKeyName(model)
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
return options
}
@@ -919,6 +925,20 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
// Set recursive flag
preloadOpt.Recursive = xfile.Recursive
// Extract relationship keys for proper foreign key filtering
if xfile.PrimaryKey != "" {
preloadOpt.PrimaryKey = xfile.PrimaryKey
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
}
if xfile.RelatedKey != "" {
preloadOpt.RelatedKey = xfile.RelatedKey
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
}
if xfile.ForeignKey != "" {
preloadOpt.ForeignKey = xfile.ForeignKey
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
}
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)