mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-01 09:44:24 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2e0c36c79 | ||
|
|
59bd709460 | ||
|
|
05962035b6 | ||
|
|
1cd04b7083 |
@@ -121,6 +121,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
|
|||||||
return b.req.URL.Query().Get(key)
|
return b.req.URL.Query().Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunRouterRequest) AllQueryParams() map[string]string {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for key, values := range b.req.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
params[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
for key, values := range b.req.Header {
|
for key, values := range b.req.Header {
|
||||||
|
|||||||
@@ -117,6 +117,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
|
|||||||
return h.req.URL.Query().Get(key)
|
return h.req.URL.Query().Get(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *HTTPRequest) AllQueryParams() map[string]string {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for key, values := range h.req.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
params[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
func (h *HTTPRequest) AllHeaders() map[string]string {
|
func (h *HTTPRequest) AllHeaders() map[string]string {
|
||||||
headers := make(map[string]string)
|
headers := make(map[string]string)
|
||||||
for key, values := range h.req.Header {
|
for key, values := range h.req.Header {
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ type Request interface {
|
|||||||
Body() ([]byte, error)
|
Body() ([]byte, error)
|
||||||
PathParam(key string) string
|
PathParam(key string) string
|
||||||
QueryParam(key string) string
|
QueryParam(key string) string
|
||||||
|
AllQueryParams() map[string]string // Get all query parameters as a map
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResponseWriter interface abstracts HTTP response
|
// ResponseWriter interface abstracts HTTP response
|
||||||
|
|||||||
136
pkg/common/sql_helpers.go
Normal file
136
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
|
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||||
|
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||||
|
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||||
|
if where == "" {
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the relation name is already present in the WHERE clause
|
||||||
|
lowerWhere := strings.ToLower(where)
|
||||||
|
lowerRelation := strings.ToLower(relationName)
|
||||||
|
|
||||||
|
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||||
|
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||||
|
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||||
|
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||||
|
// Relation prefix is already present
|
||||||
|
return where, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||||
|
// we can't safely auto-fix it - require explicit prefix
|
||||||
|
if strings.Contains(lowerWhere, " or ") ||
|
||||||
|
strings.Contains(where, "(") ||
|
||||||
|
strings.Contains(where, ")") {
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to add the relation prefix to simple column references
|
||||||
|
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||||
|
// Split by AND to handle multiple conditions (case-insensitive)
|
||||||
|
originalConditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If uppercase split didn't work, try lowercase
|
||||||
|
if len(originalConditions) == 1 {
|
||||||
|
originalConditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedConditions := make([]string, 0, len(originalConditions))
|
||||||
|
|
||||||
|
for _, cond := range originalConditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this condition already has a table prefix (contains a dot)
|
||||||
|
if strings.Contains(cond, ".") {
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
|
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||||
|
if IsSQLExpression(lowerCond) {
|
||||||
|
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||||
|
fixedConditions = append(fixedConditions, cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the column name (first identifier before operator)
|
||||||
|
columnName := ExtractColumnName(cond)
|
||||||
|
if columnName == "" {
|
||||||
|
// Can't identify column name, require explicit prefix
|
||||||
|
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add relation prefix to the column name only
|
||||||
|
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||||
|
fixedConditions = append(fixedConditions, fixedCond)
|
||||||
|
}
|
||||||
|
|
||||||
|
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||||
|
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||||
|
return fixedWhere, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||||
|
func IsSQLExpression(cond string) bool {
|
||||||
|
// Common SQL literals and expressions
|
||||||
|
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||||
|
for _, literal := range sqlLiterals {
|
||||||
|
if cond == literal {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractColumnName extracts the column name from a WHERE condition
|
||||||
|
// For example: "status = 'active'" returns "status"
|
||||||
|
func ExtractColumnName(cond string) string {
|
||||||
|
// Common SQL operators
|
||||||
|
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||||
|
|
||||||
|
for _, op := range operators {
|
||||||
|
if idx := strings.Index(cond, op); idx > 0 {
|
||||||
|
columnName := strings.TrimSpace(cond[:idx])
|
||||||
|
// Remove quotes if present
|
||||||
|
columnName = strings.Trim(columnName, "`\"'")
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||||
|
parts := strings.Fields(cond)
|
||||||
|
if len(parts) > 0 {
|
||||||
|
columnName := strings.Trim(parts[0], "`\"'")
|
||||||
|
// Check if it's a valid identifier (not a SQL keyword)
|
||||||
|
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||||
|
func IsSQLKeyword(word string) bool {
|
||||||
|
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||||
|
for _, kw := range keywords {
|
||||||
|
if word == kw {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -323,6 +323,127 @@ func ExtractColumnFromBunTag(tag string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||||
|
// This function only returns columns that:
|
||||||
|
// 1. Have bun or gorm tags (not just json tags)
|
||||||
|
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||||
|
// 3. Are not scan-only embedded fields
|
||||||
|
func GetSQLModelColumns(model any) []string {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
collectSQLColumnsFromType(modelType, &columns, false)
|
||||||
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||||
|
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||||
|
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the embedded struct itself is scan-only
|
||||||
|
isScanOnly := scanOnlyEmbedded
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
isScanOnly = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip fields in scan-only embedded structs
|
||||||
|
if scanOnlyEmbedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get bun and gorm tags
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
|
||||||
|
// Skip if neither bun nor gorm tag exists
|
||||||
|
if bunTag == "" && gormTag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if explicitly marked with "-"
|
||||||
|
if bunTag == "-" || gormTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is scan-only (bun)
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is read-only (gorm)
|
||||||
|
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (bun)
|
||||||
|
if bunTag != "" {
|
||||||
|
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||||
|
if strings.Contains(bunTag, "rel:") ||
|
||||||
|
strings.Contains(bunTag, "join:") ||
|
||||||
|
strings.Contains(bunTag, "m2m:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (gorm)
|
||||||
|
if gormTag != "" {
|
||||||
|
// Skip if it has gorm relationship tags
|
||||||
|
if strings.Contains(gormTag, "foreignKey:") ||
|
||||||
|
strings.Contains(gormTag, "references:") ||
|
||||||
|
strings.Contains(gormTag, "many2many:") ||
|
||||||
|
strings.Contains(gormTag, "constraint:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get column name
|
||||||
|
columnName := ""
|
||||||
|
if bunTag != "" {
|
||||||
|
columnName = ExtractColumnFromBunTag(bunTag)
|
||||||
|
}
|
||||||
|
if columnName == "" && gormTag != "" {
|
||||||
|
columnName = ExtractColumnFromGormTag(gormTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if we couldn't extract a column name
|
||||||
|
if columnName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
*columns = append(*columns, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsColumnWritable checks if a column can be written to in the database
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
// For bun: returns false if the field has "scanonly" tag
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
|||||||
@@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test models with relations for GetSQLModelColumns
|
||||||
|
type User struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Email string `bun:"email" json:"email"`
|
||||||
|
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||||
|
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||||
|
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||||
|
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Title string `gorm:"column:title" json:"title"`
|
||||||
|
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||||
|
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||||
|
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||||
|
Content string `json:"content"` // No bun/gorm tag
|
||||||
|
}
|
||||||
|
|
||||||
|
type Profile struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Bio string `bun:"bio" json:"bio"`
|
||||||
|
UserID int `bun:"user_id" json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model with scan-only embedded struct
|
||||||
|
type EntityWithScanOnlyEmbedded struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: User{},
|
||||||
|
// Should include: id, name, email (has bun tags)
|
||||||
|
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: Post{},
|
||||||
|
// Should include: id, title, user_id (has gorm tags)
|
||||||
|
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||||
|
expected: []string{"id", "title", "user_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded base and scan-only embedded",
|
||||||
|
model: EntityWithScanOnlyEmbedded{},
|
||||||
|
// Should include: id, name from main struct
|
||||||
|
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||||
|
expected: []string{"id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple Profile model",
|
||||||
|
model: Profile{},
|
||||||
|
// Should include all fields with bun tags
|
||||||
|
expected: []string{"id", "bio", "user_id"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetSQLModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||||
|
len(result), len(tt.expected), result, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||||
|
i, col, tt.expected[i], result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||||
|
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||||
|
user := User{}
|
||||||
|
|
||||||
|
allColumns := GetModelColumns(user)
|
||||||
|
sqlColumns := GetSQLModelColumns(user)
|
||||||
|
|
||||||
|
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||||
|
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||||
|
|
||||||
|
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||||
|
if len(allColumns) <= len(sqlColumns) {
|
||||||
|
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||||
|
for _, col := range sqlColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumns should include 'profile_data' (has json tag)
|
||||||
|
hasProfileData := false
|
||||||
|
for _, col := range allColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
hasProfileData = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasProfileData {
|
||||||
|
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
|
||||||
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
@@ -1105,69 +1110,6 @@ type relationshipInfo struct {
|
|||||||
relatedModel interface{}
|
relatedModel interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
|
||||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
|
||||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
|
||||||
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
|
||||||
if where == "" {
|
|
||||||
return where, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the relation name is already present in the WHERE clause
|
|
||||||
lowerWhere := strings.ToLower(where)
|
|
||||||
lowerRelation := strings.ToLower(relationName)
|
|
||||||
|
|
||||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
|
||||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
|
||||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
|
||||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
|
||||||
// Relation prefix is already present
|
|
||||||
return where, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
|
||||||
// we can't safely auto-fix it - require explicit prefix
|
|
||||||
if strings.Contains(lowerWhere, " or ") ||
|
|
||||||
strings.Contains(where, "(") ||
|
|
||||||
strings.Contains(where, ")") {
|
|
||||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to add the relation prefix to simple column references
|
|
||||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
|
||||||
// Split by AND to handle multiple conditions (case-insensitive)
|
|
||||||
originalConditions := strings.Split(where, " AND ")
|
|
||||||
|
|
||||||
// If uppercase split didn't work, try lowercase
|
|
||||||
if len(originalConditions) == 1 {
|
|
||||||
originalConditions = strings.Split(where, " and ")
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedConditions := make([]string, 0, len(originalConditions))
|
|
||||||
|
|
||||||
for _, cond := range originalConditions {
|
|
||||||
cond = strings.TrimSpace(cond)
|
|
||||||
if cond == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this condition already has a table prefix (contains a dot)
|
|
||||||
if strings.Contains(cond, ".") {
|
|
||||||
fixedConditions = append(fixedConditions, cond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add relation prefix to the column name
|
|
||||||
// This prefixes the entire condition with "relationName."
|
|
||||||
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
|
|
||||||
fixedConditions = append(fixedConditions, fixedCond)
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
|
||||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
|
||||||
return fixedWhere, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
@@ -1197,7 +1139,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
|
|
||||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, relationFieldName)
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||||
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||||
@@ -1208,7 +1150,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(model)
|
allCols := reflection.GetSQLModelColumns(model)
|
||||||
// Remove omitted columns
|
// Remove omitted columns
|
||||||
preload.Columns = []string{}
|
preload.Columns = []string{}
|
||||||
for _, col := range allCols {
|
for _, col := range allCols {
|
||||||
|
|||||||
@@ -200,69 +200,6 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
|
|||||||
|
|
||||||
// parseOptionsFromHeaders is now implemented in headers.go
|
// parseOptionsFromHeaders is now implemented in headers.go
|
||||||
|
|
||||||
// validateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
|
||||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
|
||||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
|
||||||
func (h *Handler) validateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
|
||||||
if where == "" {
|
|
||||||
return where, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the relation name is already present in the WHERE clause
|
|
||||||
lowerWhere := strings.ToLower(where)
|
|
||||||
lowerRelation := strings.ToLower(relationName)
|
|
||||||
|
|
||||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
|
||||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
|
||||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
|
||||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
|
||||||
// Relation prefix is already present
|
|
||||||
return where, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
|
||||||
// we can't safely auto-fix it - require explicit prefix
|
|
||||||
if strings.Contains(lowerWhere, " or ") ||
|
|
||||||
strings.Contains(where, "(") ||
|
|
||||||
strings.Contains(where, ")") {
|
|
||||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to add the relation prefix to simple column references
|
|
||||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
|
||||||
// Split by AND to handle multiple conditions (case-insensitive)
|
|
||||||
originalConditions := strings.Split(where, " AND ")
|
|
||||||
|
|
||||||
// If uppercase split didn't work, try lowercase
|
|
||||||
if len(originalConditions) == 1 {
|
|
||||||
originalConditions = strings.Split(where, " and ")
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedConditions := make([]string, 0, len(originalConditions))
|
|
||||||
|
|
||||||
for _, cond := range originalConditions {
|
|
||||||
cond = strings.TrimSpace(cond)
|
|
||||||
if cond == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this condition already has a table prefix (contains a dot)
|
|
||||||
if strings.Contains(cond, ".") {
|
|
||||||
fixedConditions = append(fixedConditions, cond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add relation prefix to the column name
|
|
||||||
// This prefixes the entire condition with "relationName."
|
|
||||||
fixedCond := fmt.Sprintf("%s.%s", relationName, cond)
|
|
||||||
fixedConditions = append(fixedConditions, fixedCond)
|
|
||||||
}
|
|
||||||
|
|
||||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
|
||||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
|
||||||
return fixedWhere, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -323,9 +260,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: X-Files configuration is now applied via parseXFiles which populates
|
// If we have computed columns/expressions but options.Columns is empty,
|
||||||
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.)
|
// populate it with all model columns first since computed columns are additions
|
||||||
// These are applied below in the normal query building process
|
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
||||||
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(options.ComputedQL) > 0 {
|
if len(options.ComputedQL) > 0 {
|
||||||
@@ -410,7 +350,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
fixedWhere, err := h.validateAndFixPreloadWhere(preload.Where, preload.Relation)
|
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||||
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||||
|
|||||||
@@ -117,15 +117,28 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
// Get all headers
|
// Get all headers
|
||||||
headers := r.AllHeaders()
|
headers := r.AllHeaders()
|
||||||
|
|
||||||
// Process each header
|
// Get all query parameters
|
||||||
|
queryParams := r.AllQueryParams()
|
||||||
|
|
||||||
|
// Merge headers and query parameters - query parameters take precedence
|
||||||
|
// This allows the same parameters to be specified in either headers or query string
|
||||||
|
combinedParams := make(map[string]string)
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
// Normalize header key to lowercase for consistent matching
|
combinedParams[key] = value
|
||||||
|
}
|
||||||
|
for key, value := range queryParams {
|
||||||
|
combinedParams[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each parameter (from both headers and query params)
|
||||||
|
for key, value := range combinedParams {
|
||||||
|
// Normalize parameter key to lowercase for consistent matching
|
||||||
normalizedKey := strings.ToLower(key)
|
normalizedKey := strings.ToLower(key)
|
||||||
|
|
||||||
// Decode value if it's base64 encoded
|
// Decode value if it's base64 encoded
|
||||||
decodedValue := decodeHeaderValue(value)
|
decodedValue := decodeHeaderValue(value)
|
||||||
|
|
||||||
// Parse based on header prefix/name
|
// Parse based on parameter prefix/name
|
||||||
switch {
|
switch {
|
||||||
// Field Selection
|
// Field Selection
|
||||||
case strings.HasPrefix(normalizedKey, "x-select-fields"):
|
case strings.HasPrefix(normalizedKey, "x-select-fields"):
|
||||||
@@ -158,7 +171,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
if strings.HasSuffix(normalizedKey, "-where") {
|
if strings.HasSuffix(normalizedKey, "-where") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
whereClaude := headers[fmt.Sprintf("%s-where", key)]
|
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
|
||||||
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
case strings.HasPrefix(normalizedKey, "x-expand"):
|
||||||
|
|||||||
403
pkg/restheadspec/query_params_test.go
Normal file
403
pkg/restheadspec/query_params_test.go
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
package restheadspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockRequest implements common.Request interface for testing
|
||||||
|
type MockRequest struct {
|
||||||
|
headers map[string]string
|
||||||
|
queryParams map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Method() string {
|
||||||
|
return "GET"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) URL() string {
|
||||||
|
return "http://example.com/test"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Header(key string) string {
|
||||||
|
return m.headers[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) AllHeaders() map[string]string {
|
||||||
|
return m.headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) Body() ([]byte, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) PathParam(key string) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) QueryParam(key string) string {
|
||||||
|
return m.queryParams[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockRequest) AllQueryParams() map[string]string {
|
||||||
|
return m.queryParams
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
headers map[string]string
|
||||||
|
validate func(t *testing.T, options ExtendedRequestOptions)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL WHERE from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set from query param")
|
||||||
|
}
|
||||||
|
expected := `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`
|
||||||
|
if options.CustomSQLWhere != expected {
|
||||||
|
t.Errorf("Expected CustomSQLWhere=%q, got %q", expected, options.CustomSQLWhere)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse sort from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-sort": "-applicationdate,name",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Sort) != 2 {
|
||||||
|
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||||
|
t.Errorf("Expected first sort: applicationdate DESC, got %s %s", options.Sort[0].Column, options.Sort[0].Direction)
|
||||||
|
}
|
||||||
|
if options.Sort[1].Column != "name" || options.Sort[1].Direction != "ASC" {
|
||||||
|
t.Errorf("Expected second sort: name ASC, got %s %s", options.Sort[1].Column, options.Sort[1].Direction)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse limit and offset from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-limit": "100",
|
||||||
|
"x-offset": "50",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
if options.Offset == nil || *options.Offset != 50 {
|
||||||
|
t.Errorf("Expected offset=50, got %v", options.Offset)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse field filters from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-fieldfilter-status": "active",
|
||||||
|
"x-fieldfilter-type": "user",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Filters) != 2 {
|
||||||
|
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check that filters were created
|
||||||
|
foundStatus := false
|
||||||
|
foundType := false
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if filter.Column == "status" && filter.Value == "active" && filter.Operator == "eq" {
|
||||||
|
foundStatus = true
|
||||||
|
}
|
||||||
|
if filter.Column == "type" && filter.Value == "user" && filter.Operator == "eq" {
|
||||||
|
foundType = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundStatus {
|
||||||
|
t.Error("Expected status filter not found")
|
||||||
|
}
|
||||||
|
if !foundType {
|
||||||
|
t.Error("Expected type filter not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse select fields from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-select-fields": "id,name,email",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
expected := []string{"id", "name", "email"}
|
||||||
|
for i, col := range expected {
|
||||||
|
if i >= len(options.Columns) || options.Columns[i] != col {
|
||||||
|
t.Errorf("Expected column[%d]=%s, got %v", i, col, options.Columns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse preload from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-preload": "posts:title,content|comments",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Preload) != 2 {
|
||||||
|
t.Errorf("Expected 2 preload options, got %d", len(options.Preload))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check first preload (posts with columns)
|
||||||
|
if options.Preload[0].Relation != "posts" {
|
||||||
|
t.Errorf("Expected first preload relation=posts, got %s", options.Preload[0].Relation)
|
||||||
|
}
|
||||||
|
if len(options.Preload[0].Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns for posts preload, got %d", len(options.Preload[0].Columns))
|
||||||
|
}
|
||||||
|
// Check second preload (comments without columns)
|
||||||
|
if options.Preload[1].Relation != "comments" {
|
||||||
|
t.Errorf("Expected second preload relation=comments, got %s", options.Preload[1].Relation)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Query params take precedence over headers",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-limit": "100",
|
||||||
|
},
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Limit": "50",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected query param limit=100 to override header, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse search operators from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-searchop-contains-name": "john",
|
||||||
|
"x-searchop-gt-age": "18",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if len(options.Filters) != 2 {
|
||||||
|
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Check for ILIKE filter
|
||||||
|
foundContains := false
|
||||||
|
foundGt := false
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
if filter.Column == "name" && filter.Operator == "ilike" {
|
||||||
|
foundContains = true
|
||||||
|
}
|
||||||
|
if filter.Column == "age" && filter.Operator == "gt" && filter.Value == "18" {
|
||||||
|
foundGt = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundContains {
|
||||||
|
t.Error("Expected contains filter not found")
|
||||||
|
}
|
||||||
|
if !foundGt {
|
||||||
|
t.Error("Expected gt filter not found")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse complex example with multiple params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0)`,
|
||||||
|
"x-sort": "-applicationdate",
|
||||||
|
"x-limit": "100",
|
||||||
|
"x-select-fields": "id,name,status",
|
||||||
|
"x-fieldfilter-active": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
// Validate CustomSQLWhere
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set")
|
||||||
|
}
|
||||||
|
// Validate Sort
|
||||||
|
if len(options.Sort) != 1 || options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||||
|
t.Errorf("Expected sort by applicationdate DESC, got %v", options.Sort)
|
||||||
|
}
|
||||||
|
// Validate Limit
|
||||||
|
if options.Limit == nil || *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||||
|
}
|
||||||
|
// Validate Columns
|
||||||
|
if len(options.Columns) != 3 {
|
||||||
|
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||||
|
}
|
||||||
|
// Validate Filters
|
||||||
|
if len(options.Filters) < 1 {
|
||||||
|
t.Error("Expected at least 1 filter")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse distinct flag from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-distinct": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if !options.Distinct {
|
||||||
|
t.Error("Expected Distinct to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse skip count flag from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-skipcount": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if !options.SkipCount {
|
||||||
|
t.Error("Expected SkipCount to be true")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse response format from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-syncfusion": "true",
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.ResponseFormat != "syncfusion" {
|
||||||
|
t.Errorf("Expected ResponseFormat=syncfusion, got %s", options.ResponseFormat)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parse custom SQL OR from query params",
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-custom-sql-or": `("field1" = 'value1' OR "field2" = 'value2')`,
|
||||||
|
},
|
||||||
|
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||||
|
if options.CustomSQLOr == "" {
|
||||||
|
t.Error("Expected CustomSQLOr to be set")
|
||||||
|
}
|
||||||
|
expected := `("field1" = 'value1' OR "field2" = 'value2')`
|
||||||
|
if options.CustomSQLOr != expected {
|
||||||
|
t.Errorf("Expected CustomSQLOr=%q, got %q", expected, options.CustomSQLOr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Create mock request
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: tt.headers,
|
||||||
|
queryParams: tt.queryParams,
|
||||||
|
}
|
||||||
|
if req.headers == nil {
|
||||||
|
req.headers = make(map[string]string)
|
||||||
|
}
|
||||||
|
if req.queryParams == nil {
|
||||||
|
req.queryParams = make(map[string]string)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse options
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
tt.validate(t, options)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryParamsWithURLEncoding(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
// Test with URL-encoded query parameter (like the user's example)
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: make(map[string]string),
|
||||||
|
queryParams: map[string]string{
|
||||||
|
// URL-encoded version of the SQL WHERE clause
|
||||||
|
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null) and ("v_webui_clients".inactive = 0 or "v_webui_clients".inactive is null)`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
if options.CustomSQLWhere == "" {
|
||||||
|
t.Error("Expected CustomSQLWhere to be set from URL-encoded query param")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The SQL should contain the expected conditions
|
||||||
|
if !contains(options.CustomSQLWhere, "clientstatus") {
|
||||||
|
t.Error("Expected CustomSQLWhere to contain 'clientstatus'")
|
||||||
|
}
|
||||||
|
if !contains(options.CustomSQLWhere, "inactive") {
|
||||||
|
t.Error("Expected CustomSQLWhere to contain 'inactive'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
||||||
|
handler := NewHandler(nil, nil)
|
||||||
|
|
||||||
|
// Test that headers and query params can work together
|
||||||
|
req := &MockRequest{
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Select-Fields": "id,name",
|
||||||
|
"X-Limit": "50",
|
||||||
|
},
|
||||||
|
queryParams: map[string]string{
|
||||||
|
"x-sort": "-created_at",
|
||||||
|
"x-offset": "10",
|
||||||
|
// This should override the header value
|
||||||
|
"x-limit": "100",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
options := handler.parseOptionsFromHeaders(req, nil)
|
||||||
|
|
||||||
|
// Verify columns from header
|
||||||
|
if len(options.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns from header, got %d", len(options.Columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify sort from query param
|
||||||
|
if len(options.Sort) != 1 || options.Sort[0].Column != "created_at" {
|
||||||
|
t.Errorf("Expected sort from query param, got %v", options.Sort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify offset from query param
|
||||||
|
if options.Offset == nil || *options.Offset != 10 {
|
||||||
|
t.Errorf("Expected offset=10 from query param, got %v", options.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify limit from query param (should override header)
|
||||||
|
if options.Limit == nil {
|
||||||
|
t.Error("Expected limit to be set from query param")
|
||||||
|
} else if *options.Limit != 100 {
|
||||||
|
t.Errorf("Expected limit=100 from query param (overriding header), got %d", *options.Limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to check if a string contains a substring
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user