Compare commits

..

8 Commits

Author SHA1 Message Date
Hein
7853a3f56a cql_columns parsing and recursive preloading. Also added legacy header support for limt(s,e) ,sort(x,y,-z) 2025-11-21 09:15:40 +02:00
Hein
c2e0c36c79 Restheadspec now takes parameters from query parameters and headers. Allows for backward compatibility with our old dojo clients 2025-11-21 08:56:58 +02:00
Hein
59bd709460 More reflection function to handle sql columns and get default sqlcolumn lists. 2025-11-21 08:35:46 +02:00
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
Hein
745564f2e7 More Panic Recovery for reflection on orm 2025-11-20 15:20:21 +02:00
Hein
311e50bfdd Better relation lookup
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 14:30:59 +02:00
15 changed files with 1401 additions and 99 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
@@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
return &BunResult{result: result}, err
}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
}
@@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
return nil
}
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction
adapter := &BunTxAdapter{tx: tx}
@@ -219,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
}
}()
if len(apply) == 0 {
return sq
}
@@ -276,15 +297,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
return b
}
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
}()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
}
return b.query.Scan(ctx, dest)
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
}()
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
}
return b.query.Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
}()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
@@ -293,15 +337,20 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
// Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model
var count int
err := b.db.NewSelect().
err = b.db.NewSelect().
TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)").
Scan(ctx, &count)
return count, err
}
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
}()
return b.query.Exists(ctx)
}
@@ -320,7 +369,6 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
if b.hasModel {
// If model is set, do not override table name
return b
}
b.query = b.query.Table(table)
@@ -347,7 +395,12 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
return b
}
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
@@ -428,7 +481,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return b
}
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
return &BunResult{result: result}, err
}
@@ -453,7 +511,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
return b
}
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
return &BunResult{result: result}, err
}

View File

@@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
@@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
return &GormResult{result: result}, result.Error
}
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
}
@@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
return g.db.WithContext(ctx).Rollback().Error
}
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx}
return fn(adapter)
@@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
return g
}
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
return g.db.WithContext(ctx).Find(dest).Error
}
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
}
}()
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
var count int64
err := g.db.WithContext(ctx).Count(&count).Error
return int(count), err
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
return int(count64), err
}
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false
}
}()
var count int64
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
return count > 0, err
}
@@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
return g
}
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
@@ -401,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return g
}
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
return &GormResult{result: result}, result.Error
}
@@ -428,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
return g
}
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
return &GormResult{result: result}, result.Error
}

View File

@@ -121,6 +121,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
return b.req.URL.Query().Get(key)
}
func (b *BunRouterRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range b.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (b *BunRouterRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range b.req.Header {

View File

@@ -117,6 +117,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
return h.req.URL.Query().Get(key)
}
func (h *HTTPRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range h.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (h *HTTPRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range h.req.Header {

View File

@@ -116,6 +116,7 @@ type Request interface {
Body() ([]byte, error)
PathParam(key string) string
QueryParam(key string) string
AllQueryParams() map[string]string // Get all query parameters as a map
}
// ResponseWriter interface abstracts HTTP response

136
pkg/common/sql_helpers.go Normal file
View File

@@ -0,0 +1,136 @@
package common
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
// the relation prefix (alias). If not present, it attempts to add it to column references.
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
// Check if the relation name is already present in the WHERE clause
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
if strings.Contains(lowerWhere, lowerRelation+".") ||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
// Relation prefix is already present
return where, nil
}
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
// we can't safely auto-fix it - require explicit prefix
if strings.Contains(lowerWhere, " or ") ||
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Check if this is a SQL expression/literal that shouldn't be prefixed
lowerCond := strings.ToLower(strings.TrimSpace(cond))
if IsSQLExpression(lowerCond) {
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
fixedConditions = append(fixedConditions, cond)
continue
}
// Extract the column name (first identifier before operator)
columnName := ExtractColumnName(cond)
if columnName == "" {
// Can't identify column name, require explicit prefix
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
}
// Add relation prefix to the column name only
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}

View File

@@ -32,15 +32,17 @@ type Parameter struct {
}
type PreloadOption struct {
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
}
type FilterOption struct {

View File

@@ -103,3 +103,18 @@ func CatchPanicCallback(location string, cb func(err any)) {
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
}
// HandlePanic logs a panic and returns it as an error
// This should be called with the result of recover() from a deferred function
// Example usage:
//
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
return fmt.Errorf("panic in %s: %v", methodName, r)
}

View File

@@ -323,6 +323,127 @@ func ExtractColumnFromBunTag(tag string) string {
return ""
}
// GetSQLModelColumns extracts column names that have valid SQL field mappings
// This function only returns columns that:
// 1. Have bun or gorm tags (not just json tags)
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
// 3. Are not scan-only embedded fields
func GetSQLModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
collectSQLColumnsFromType(modelType, &columns, false)
return columns
}
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Check if the embedded struct itself is scan-only
isScanOnly := scanOnlyEmbedded
bunTag := field.Tag.Get("bun")
if bunTag != "" && isBunFieldScanOnly(bunTag) {
isScanOnly = true
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
continue
}
}
// Skip fields in scan-only embedded structs
if scanOnlyEmbedded {
continue
}
// Get bun and gorm tags
bunTag := field.Tag.Get("bun")
gormTag := field.Tag.Get("gorm")
// Skip if neither bun nor gorm tag exists
if bunTag == "" && gormTag == "" {
continue
}
// Skip if explicitly marked with "-"
if bunTag == "-" || gormTag == "-" {
continue
}
// Skip if field itself is scan-only (bun)
if bunTag != "" && isBunFieldScanOnly(bunTag) {
continue
}
// Skip if field itself is read-only (gorm)
if gormTag != "" && isGormFieldReadOnly(gormTag) {
continue
}
// Skip relation fields (bun)
if bunTag != "" {
// Skip if it's a bun relation (rel:, join:, or m2m:)
if strings.Contains(bunTag, "rel:") ||
strings.Contains(bunTag, "join:") ||
strings.Contains(bunTag, "m2m:") {
continue
}
}
// Skip relation fields (gorm)
if gormTag != "" {
// Skip if it has gorm relationship tags
if strings.Contains(gormTag, "foreignKey:") ||
strings.Contains(gormTag, "references:") ||
strings.Contains(gormTag, "many2many:") ||
strings.Contains(gormTag, "constraint:") {
continue
}
}
// Get column name
columnName := ""
if bunTag != "" {
columnName = ExtractColumnFromBunTag(bunTag)
}
if columnName == "" && gormTag != "" {
columnName = ExtractColumnFromGormTag(gormTag)
}
// Skip if we couldn't extract a column name
if columnName == "" {
continue
}
*columns = append(*columns, columnName)
}
}
// IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag

View File

@@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) {
})
}
}
// Test models with relations for GetSQLModelColumns
type User struct {
ID int `bun:"id,pk" json:"id"`
Name string `bun:"name" json:"name"`
Email string `bun:"email" json:"email"`
ProfileData string `json:"profile_data"` // No bun/gorm tag
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
}
type Post struct {
ID int `gorm:"column:id;primaryKey" json:"id"`
Title string `gorm:"column:title" json:"title"`
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
Content string `json:"content"` // No bun/gorm tag
}
type Profile struct {
ID int `bun:"id,pk" json:"id"`
Bio string `bun:"bio" json:"bio"`
UserID int `bun:"user_id" json:"user_id"`
}
type Tag struct {
ID int `gorm:"column:id;primaryKey" json:"id"`
Name string `gorm:"column:name" json:"name"`
}
// Model with scan-only embedded struct
type EntityWithScanOnlyEmbedded struct {
ID int `bun:"id,pk" json:"id"`
Name string `bun:"name" json:"name"`
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
}
func TestGetSQLModelColumns(t *testing.T) {
tests := []struct {
name string
model any
expected []string
}{
{
name: "Bun model with relations - excludes relations and non-SQL fields",
model: User{},
// Should include: id, name, email (has bun tags)
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
expected: []string{"id", "name", "email"},
},
{
name: "GORM model with relations - excludes relations and non-SQL fields",
model: Post{},
// Should include: id, title, user_id (has gorm tags)
// Should exclude: content (no gorm tag), User/Tags (relations)
expected: []string{"id", "title", "user_id"},
},
{
name: "Model with embedded base and scan-only embedded",
model: EntityWithScanOnlyEmbedded{},
// Should include: id, name from main struct
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
expected: []string{"id", "name"},
},
{
name: "Model with embedded - includes SQL fields, excludes scan-only",
model: ModelWithEmbedded{},
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
expected: []string{"rid_base", "created_at", "name", "description"},
},
{
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
model: GormModelWithEmbedded{},
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
expected: []string{"rid_base", "created_at", "name", "description"},
},
{
name: "Simple Profile model",
model: Profile{},
// Should include all fields with bun tags
expected: []string{"id", "bio", "user_id"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSQLModelColumns(tt.model)
if len(result) != len(tt.expected) {
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
len(result), len(tt.expected), result, tt.expected)
return
}
for i, col := range result {
if col != tt.expected[i] {
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
i, col, tt.expected[i], result)
}
}
})
}
}
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
user := User{}
allColumns := GetModelColumns(user)
sqlColumns := GetSQLModelColumns(user)
t.Logf("GetModelColumns(User): %v", allColumns)
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
// GetModelColumns should return more columns (includes fields with only json tags)
if len(allColumns) <= len(sqlColumns) {
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
}
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
for _, col := range sqlColumns {
if col == "profile_data" {
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
}
}
// GetModelColumns should include 'profile_data' (has json tag)
hasProfileData := false
for _, col := range allColumns {
if col == "profile_data" {
hasProfileData = true
break
}
}
if !hasProfileData {
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
}
}

View File

@@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName)
}
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetSQLModelColumns(model)
}
// Apply column selection
if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns)
@@ -1132,15 +1137,20 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName
// For now, we'll preload without conditions
// TODO: Implement column selection and filtering for preloads
// This requires a more sophisticated approach with callbacks or query builders
// Apply preloading
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
}
preload.Where = fixedWhere
}
logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)
allCols := reflection.GetSQLModelColumns(model)
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {

View File

@@ -13,6 +13,7 @@ const (
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
contextKeyOptions contextKey = "options"
)
// WithSchema adds schema to context
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
// WithOptions adds request options to context
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
return context.WithValue(ctx, contextKeyOptions, options)
}
// GetOptions retrieves request options from context
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
if v := ctx.Value(contextKeyOptions); v != nil {
if opts, ok := v.(ExtendedRequestOptions); ok {
return &opts
}
}
return nil
}
// WithRequestData adds all request-scoped data to context at once
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
ctx = WithOptions(ctx, options)
return ctx
}

View File

@@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
entity := params["entity"]
id := params["id"]
// Parse options from headers (now returns ExtendedRequestOptions)
options := h.parseOptionsFromHeaders(r)
// Determine operation based on HTTP method
method := r.Method()
@@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Parse options from headers - this now includes relation name resolution
options := h.parseOptionsFromHeaders(r, model)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
switch method {
case "GET":
if id != "" {
@@ -260,9 +260,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName)
}
// Note: X-Files configuration is now applied via parseXFiles which populates
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.)
// These are applied below in the normal query building process
// If we have computed columns/expressions but options.Columns is empty,
// populate it with all model columns first since computed columns are additions
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetSQLModelColumns(model)
}
// Apply ComputedQL fields if any
if len(options.ComputedQL) > 0 {
@@ -344,50 +347,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
for idx := range options.Preload {
preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation)
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
return
}
preload.Where = fixedWhere
}
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
}
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
return sq
})
// Apply the preload with recursive support
query = h.applyPreloadWithRecursion(query, preload, model, 0)
}
// Apply DISTINCT if requested
@@ -573,6 +547,111 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
h.sendFormattedResponse(w, modelPtr, metadata, options)
}
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
// Apply the preload
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
// Get the related model for column operations
relatedModel := h.getRelationModel(model, preload.Relation)
if relatedModel == nil {
logger.Warn("Could not get related model for preload: %s", preload.Relation)
relatedModel = model // fallback to parent model
}
// If we have computed columns but no explicit columns, populate with all model columns first
// since computed columns are additions
if len(preload.Columns) == 0 && len(preload.ComputedQL) > 0 && relatedModel != nil {
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
}
// Apply ComputedQL fields if any
if len(preload.ComputedQL) > 0 {
for colName, colExpr := range preload.ComputedQL {
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
// Remove the computed column from selected columns to avoid duplication
for colIndex := range preload.Columns {
if preload.Columns[colIndex] == colName {
preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...)
break
}
}
}
}
// Handle OmitColumns
if len(preload.OmitColumns) > 0 && relatedModel != nil {
allCols := reflection.GetModelColumns(relatedModel)
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
// Apply column selection
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
}
// Apply filters
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
// Apply sorting
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
// Apply WHERE clause
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
}
// Apply limit
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
return sq
})
// Handle recursive preloading
if preload.Recursive && depth < 5 {
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
// For recursive relationships, we need to get the last part of the relation path
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
relationParts := strings.Split(preload.Relation, ".")
lastRelationName := relationParts[len(relationParts)-1]
// Create a recursive preload with the same configuration
// but with the relation path extended
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + lastRelationName
// Recursively apply preload until we reach depth 5
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
}
return query
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response
defer func() {

View File

@@ -99,7 +99,8 @@ func DecodeParam(pStr string) (string, error) {
}
// parseOptionsFromHeaders parses all request options from HTTP headers
func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions {
// If model is provided, it will resolve table names to field names in preload/expand options
func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions {
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Filters: make([]common.FilterOption, 0),
@@ -109,22 +110,35 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
AdvancedSQL: make(map[string]string),
ComputedQL: make(map[string]string),
Expand: make([]ExpandOption, 0),
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
ResponseFormat: "simple", // Default response format
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
}
// Get all headers
headers := r.AllHeaders()
// Process each header
// Get all query parameters
queryParams := r.AllQueryParams()
// Merge headers and query parameters - query parameters take precedence
// This allows the same parameters to be specified in either headers or query string
combinedParams := make(map[string]string)
for key, value := range headers {
// Normalize header key to lowercase for consistent matching
combinedParams[key] = value
}
for key, value := range queryParams {
combinedParams[key] = value
}
// Process each parameter (from both headers and query params)
for key, value := range combinedParams {
// Normalize parameter key to lowercase for consistent matching
normalizedKey := strings.ToLower(key)
// Decode value if it's base64 encoded
decodedValue := decodeHeaderValue(value)
// Parse based on header prefix/name
// Parse based on parameter prefix/name
switch {
// Field Selection
case strings.HasPrefix(normalizedKey, "x-select-fields"):
@@ -157,7 +171,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
if strings.HasSuffix(normalizedKey, "-where") {
continue
}
whereClaude := headers[fmt.Sprintf("%s-where", key)]
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
case strings.HasPrefix(normalizedKey, "x-expand"):
@@ -169,14 +183,37 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
// Sorting & Pagination
case strings.HasPrefix(normalizedKey, "x-sort"):
h.parseSorting(&options, decodedValue)
//Special cases for older clients using sort(a,b,-c) syntax
case strings.HasPrefix(normalizedKey, "sort(") && strings.Contains(normalizedKey, ")"):
sortValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
h.parseSorting(&options, sortValue)
case strings.HasPrefix(normalizedKey, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil {
options.Limit = &limit
}
//Special cases for older clients using limit(n) syntax
case strings.HasPrefix(normalizedKey, "limit(") && strings.Contains(normalizedKey, ")"):
limitValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
limitValueParts := strings.Split(limitValue, ",")
if len(limitValueParts) > 1 {
if offset, err := strconv.Atoi(limitValueParts[0]); err == nil {
options.Offset = &offset
}
if limit, err := strconv.Atoi(limitValueParts[1]); err == nil {
options.Limit = &limit
}
} else {
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
options.Limit = &limit
}
}
case strings.HasPrefix(normalizedKey, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil {
options.Offset = &offset
}
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
options.CursorForward = decodedValue
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
@@ -225,6 +262,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
}
}
// Resolve relation names (convert table names to field names) if model is provided
if model != nil {
h.resolveRelationNamesInOptions(&options, model)
}
return options
}
@@ -655,6 +697,192 @@ func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedReques
}
}
// resolveRelationNamesInOptions resolves all table names to field names in preload options
// This is called internally by parseOptionsFromHeaders when a model is provided
func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) {
if options == nil || model == nil {
return
}
// Resolve relation names in all preload options
for i := range options.Preload {
preload := &options.Preload[i]
// Split the relation path (e.g., "parent.child.grandchild")
parts := strings.Split(preload.Relation, ".")
resolvedParts := make([]string, 0, len(parts))
// Resolve each part of the path
currentModel := model
for _, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part)
resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level
// This allows nested resolution
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
currentModel = nextModel
}
}
// Update the relation path with resolved names
resolvedPath := strings.Join(resolvedParts, ".")
if resolvedPath != preload.Relation {
logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath)
preload.Relation = resolvedPath
}
}
// Resolve relation names in expand options
for i := range options.Expand {
expand := &options.Expand[i]
resolved := h.resolveRelationName(model, expand.Relation)
if resolved != expand.Relation {
logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved)
expand.Relation = resolved
}
}
}
// getRelationModel gets the model type for a relation field
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}
// resolveRelationName resolves a relation name or table name to the actual field name in the model
// If the input is already a field name, it returns it as-is
// If the input is a table name, it looks up the corresponding relation field
func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string {
if model == nil || nameOrTable == "" {
return nameOrTable
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Check again after dereferencing
if modelType == nil {
return nameOrTable
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return nameOrTable
}
// First, check if the input matches a field name directly
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if field.Name == nameOrTable {
// It's already a field name
logger.Debug("Input '%s' is a field name", nameOrTable)
return nameOrTable
}
}
// If not found as a field name, try to look it up as a table name
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
// Check if it's a slice or pointer to a struct
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil {
// Dereference pointer if the slice contains pointers
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
// Check if it's a struct type
if targetType.Kind() == reflect.Struct {
// Get the type name and normalize it
typeName := targetType.Name()
// Extract the table name from type name
// Patterns: ModelCoreMastertaskitem -> mastertaskitem
// ModelMastertaskitem -> mastertaskitem
normalizedTypeName := strings.ToLower(typeName)
// Remove common prefixes like "model", "modelcore", etc.
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
// Compare normalized names
if normalizedTypeName == normalizedInput {
logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName)
return field.Name
}
}
}
}
// If no match found, return the original input
logger.Debug("No field found for '%s', using as-is", nameOrTable)
return nameOrTable
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
@@ -662,7 +890,8 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
return
}
// Determine the relation path
// Store the table name as-is for now - it will be resolved to field name later
// when we have the model instance available
relationPath := xfile.TableName
if basePath != "" {
relationPath = basePath + "." + xfile.TableName
@@ -729,6 +958,19 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// Add computed columns (CQL) -> ComputedQL
if len(xfile.CQLColumns) > 0 {
preloadOpt.ComputedQL = make(map[string]string)
for i, cqlExpr := range xfile.CQLColumns {
colName := fmt.Sprintf("cql%d", i+1)
preloadOpt.ComputedQL[colName] = cqlExpr
logger.Debug("X-Files: Added computed column %s to preload %s: %s", colName, relationPath, cqlExpr)
}
}
// Set recursive flag
preloadOpt.Recursive = xfile.Recursive
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)

View File

@@ -0,0 +1,403 @@
package restheadspec
import (
"testing"
)
// MockRequest implements common.Request interface for testing
type MockRequest struct {
headers map[string]string
queryParams map[string]string
}
func (m *MockRequest) Method() string {
return "GET"
}
func (m *MockRequest) URL() string {
return "http://example.com/test"
}
func (m *MockRequest) Header(key string) string {
return m.headers[key]
}
func (m *MockRequest) AllHeaders() map[string]string {
return m.headers
}
func (m *MockRequest) Body() ([]byte, error) {
return nil, nil
}
func (m *MockRequest) PathParam(key string) string {
return ""
}
func (m *MockRequest) QueryParam(key string) string {
return m.queryParams[key]
}
func (m *MockRequest) AllQueryParams() map[string]string {
return m.queryParams
}
func TestParseOptionsFromQueryParams(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, options ExtendedRequestOptions)
}{
{
name: "Parse custom SQL WHERE from query params",
queryParams: map[string]string{
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if options.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set from query param")
}
expected := `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`
if options.CustomSQLWhere != expected {
t.Errorf("Expected CustomSQLWhere=%q, got %q", expected, options.CustomSQLWhere)
}
},
},
{
name: "Parse sort from query params",
queryParams: map[string]string{
"x-sort": "-applicationdate,name",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.Sort) != 2 {
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
return
}
if options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
t.Errorf("Expected first sort: applicationdate DESC, got %s %s", options.Sort[0].Column, options.Sort[0].Direction)
}
if options.Sort[1].Column != "name" || options.Sort[1].Direction != "ASC" {
t.Errorf("Expected second sort: name ASC, got %s %s", options.Sort[1].Column, options.Sort[1].Direction)
}
},
},
{
name: "Parse limit and offset from query params",
queryParams: map[string]string{
"x-limit": "100",
"x-offset": "50",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if options.Limit == nil || *options.Limit != 100 {
t.Errorf("Expected limit=100, got %v", options.Limit)
}
if options.Offset == nil || *options.Offset != 50 {
t.Errorf("Expected offset=50, got %v", options.Offset)
}
},
},
{
name: "Parse field filters from query params",
queryParams: map[string]string{
"x-fieldfilter-status": "active",
"x-fieldfilter-type": "user",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.Filters) != 2 {
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
return
}
// Check that filters were created
foundStatus := false
foundType := false
for _, filter := range options.Filters {
if filter.Column == "status" && filter.Value == "active" && filter.Operator == "eq" {
foundStatus = true
}
if filter.Column == "type" && filter.Value == "user" && filter.Operator == "eq" {
foundType = true
}
}
if !foundStatus {
t.Error("Expected status filter not found")
}
if !foundType {
t.Error("Expected type filter not found")
}
},
},
{
name: "Parse select fields from query params",
queryParams: map[string]string{
"x-select-fields": "id,name,email",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.Columns) != 3 {
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
return
}
expected := []string{"id", "name", "email"}
for i, col := range expected {
if i >= len(options.Columns) || options.Columns[i] != col {
t.Errorf("Expected column[%d]=%s, got %v", i, col, options.Columns)
}
}
},
},
{
name: "Parse preload from query params",
queryParams: map[string]string{
"x-preload": "posts:title,content|comments",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.Preload) != 2 {
t.Errorf("Expected 2 preload options, got %d", len(options.Preload))
return
}
// Check first preload (posts with columns)
if options.Preload[0].Relation != "posts" {
t.Errorf("Expected first preload relation=posts, got %s", options.Preload[0].Relation)
}
if len(options.Preload[0].Columns) != 2 {
t.Errorf("Expected 2 columns for posts preload, got %d", len(options.Preload[0].Columns))
}
// Check second preload (comments without columns)
if options.Preload[1].Relation != "comments" {
t.Errorf("Expected second preload relation=comments, got %s", options.Preload[1].Relation)
}
},
},
{
name: "Query params take precedence over headers",
queryParams: map[string]string{
"x-limit": "100",
},
headers: map[string]string{
"X-Limit": "50",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if options.Limit == nil || *options.Limit != 100 {
t.Errorf("Expected query param limit=100 to override header, got %v", options.Limit)
}
},
},
{
name: "Parse search operators from query params",
queryParams: map[string]string{
"x-searchop-contains-name": "john",
"x-searchop-gt-age": "18",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if len(options.Filters) != 2 {
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
return
}
// Check for ILIKE filter
foundContains := false
foundGt := false
for _, filter := range options.Filters {
if filter.Column == "name" && filter.Operator == "ilike" {
foundContains = true
}
if filter.Column == "age" && filter.Operator == "gt" && filter.Value == "18" {
foundGt = true
}
}
if !foundContains {
t.Error("Expected contains filter not found")
}
if !foundGt {
t.Error("Expected gt filter not found")
}
},
},
{
name: "Parse complex example with multiple params",
queryParams: map[string]string{
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0)`,
"x-sort": "-applicationdate",
"x-limit": "100",
"x-select-fields": "id,name,status",
"x-fieldfilter-active": "true",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
// Validate CustomSQLWhere
if options.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
// Validate Sort
if len(options.Sort) != 1 || options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
t.Errorf("Expected sort by applicationdate DESC, got %v", options.Sort)
}
// Validate Limit
if options.Limit == nil || *options.Limit != 100 {
t.Errorf("Expected limit=100, got %v", options.Limit)
}
// Validate Columns
if len(options.Columns) != 3 {
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
}
// Validate Filters
if len(options.Filters) < 1 {
t.Error("Expected at least 1 filter")
}
},
},
{
name: "Parse distinct flag from query params",
queryParams: map[string]string{
"x-distinct": "true",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if !options.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse skip count flag from query params",
queryParams: map[string]string{
"x-skipcount": "true",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if !options.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format from query params",
queryParams: map[string]string{
"x-syncfusion": "true",
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if options.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", options.ResponseFormat)
}
},
},
{
name: "Parse custom SQL OR from query params",
queryParams: map[string]string{
"x-custom-sql-or": `("field1" = 'value1' OR "field2" = 'value2')`,
},
validate: func(t *testing.T, options ExtendedRequestOptions) {
if options.CustomSQLOr == "" {
t.Error("Expected CustomSQLOr to be set")
}
expected := `("field1" = 'value1' OR "field2" = 'value2')`
if options.CustomSQLOr != expected {
t.Errorf("Expected CustomSQLOr=%q, got %q", expected, options.CustomSQLOr)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create mock request
req := &MockRequest{
headers: tt.headers,
queryParams: tt.queryParams,
}
if req.headers == nil {
req.headers = make(map[string]string)
}
if req.queryParams == nil {
req.queryParams = make(map[string]string)
}
// Parse options
options := handler.parseOptionsFromHeaders(req, nil)
// Validate
tt.validate(t, options)
})
}
}
func TestQueryParamsWithURLEncoding(t *testing.T) {
handler := NewHandler(nil, nil)
// Test with URL-encoded query parameter (like the user's example)
req := &MockRequest{
headers: make(map[string]string),
queryParams: map[string]string{
// URL-encoded version of the SQL WHERE clause
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null) and ("v_webui_clients".inactive = 0 or "v_webui_clients".inactive is null)`,
},
}
options := handler.parseOptionsFromHeaders(req, nil)
if options.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set from URL-encoded query param")
}
// The SQL should contain the expected conditions
if !contains(options.CustomSQLWhere, "clientstatus") {
t.Error("Expected CustomSQLWhere to contain 'clientstatus'")
}
if !contains(options.CustomSQLWhere, "inactive") {
t.Error("Expected CustomSQLWhere to contain 'inactive'")
}
}
func TestHeadersAndQueryParamsCombined(t *testing.T) {
handler := NewHandler(nil, nil)
// Test that headers and query params can work together
req := &MockRequest{
headers: map[string]string{
"X-Select-Fields": "id,name",
"X-Limit": "50",
},
queryParams: map[string]string{
"x-sort": "-created_at",
"x-offset": "10",
// This should override the header value
"x-limit": "100",
},
}
options := handler.parseOptionsFromHeaders(req, nil)
// Verify columns from header
if len(options.Columns) != 2 {
t.Errorf("Expected 2 columns from header, got %d", len(options.Columns))
}
// Verify sort from query param
if len(options.Sort) != 1 || options.Sort[0].Column != "created_at" {
t.Errorf("Expected sort from query param, got %v", options.Sort)
}
// Verify offset from query param
if options.Offset == nil || *options.Offset != 10 {
t.Errorf("Expected offset=10 from query param, got %v", options.Offset)
}
// Verify limit from query param (should override header)
if options.Limit == nil {
t.Error("Expected limit to be set from query param")
} else if *options.Limit != 100 {
t.Errorf("Expected limit=100 from query param (overriding header), got %d", *options.Limit)
}
}
// Helper function to check if a string contains a substring
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
}
func containsHelper(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}