diff --git a/pkg/reflection/helpers.go b/pkg/reflection/helpers.go index fe9520c..cc6787f 100644 --- a/pkg/reflection/helpers.go +++ b/pkg/reflection/helpers.go @@ -17,3 +17,33 @@ func Len(v any) int { return 0 } } + +// ExtractTableNameOnly extracts the table name from a fully qualified table reference. +// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at +// the first delimiter (comma, space, tab, or newline). If the input contains multiple +// dots, it returns everything after the last dot up to the first delimiter. +func ExtractTableNameOnly(fullName string) string { + // First, split by dot to remove schema prefix if present + lastDotIndex := -1 + for i, char := range fullName { + if char == '.' { + lastDotIndex = i + } + } + + // Start from after the last dot (or from beginning if no dot) + startIndex := 0 + if lastDotIndex != -1 { + startIndex = lastDotIndex + 1 + } + + // Now find the end (first delimiter after the table name) + for i := startIndex; i < len(fullName); i++ { + char := rune(fullName[i]) + if char == ',' || char == ' ' || char == '\t' || char == '\n' { + return fullName[startIndex:i] + } + } + + return fullName[startIndex:] +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 65d20c8..5ded143 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -1209,7 +1209,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre } if len(preload.Where) > 0 { - sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) + sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index be1bdeb..40a395b 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -392,7 +392,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st if options.CustomSQLWhere != "" { logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) // Sanitize without auto-prefixing since custom SQL may reference multiple tables - sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "") + sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName)) if sanitizedWhere != "" { query = query.Where(sanitizedWhere) } @@ -402,7 +402,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st if options.CustomSQLOr != "" { logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) // Sanitize without auto-prefixing since custom SQL may reference multiple tables - sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "") + sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName)) if sanitizedOr != "" { query = query.WhereOr(sanitizedOr) } @@ -481,7 +481,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // Apply cursor filter to query if cursorFilter != "" { logger.Debug("Applying cursor filter: %s", cursorFilter) - sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "") + sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName)) if sanitizedCursor != "" { query = query.Where(sanitizedCursor) } @@ -655,7 +655,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co // Apply WHERE clause if len(preload.Where) > 0 { - sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation) + sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation)) if len(sanitizedWhere) > 0 { sq = sq.Where(sanitizedWhere) }