Compare commits

...

8 Commits

Author SHA1 Message Date
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
Hein
745564f2e7 More Panic Recovery for reflection on orm 2025-11-20 15:20:21 +02:00
Hein
311e50bfdd Better relation lookup
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 14:30:59 +02:00
Hein
c95bc9e633 Added x-files feature 2025-11-20 12:47:36 +02:00
Hein
07b09e2025 handle JSON sql columns 2025-11-20 12:04:19 +02:00
Hein
3d5334002d Fixes on Table Name on insert 2025-11-20 11:49:07 +02:00
13 changed files with 1671 additions and 54 deletions

View File

@@ -9,6 +9,7 @@ import (
"github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
@@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
return &BunResult{result: result}, err
}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
}
@@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
return nil
}
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction
adapter := &BunTxAdapter{tx: tx}
@@ -219,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
}
}()
if len(apply) == 0 {
return sq
}
@@ -276,15 +297,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
return b
}
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
}()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
}
return b.query.Scan(ctx, dest)
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
}()
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
}
return b.query.Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
}()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
@@ -293,15 +337,20 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
// Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model
var count int
err := b.db.NewSelect().
err = b.db.NewSelect().
TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)").
Scan(ctx, &count)
return count, err
}
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
}()
return b.query.Exists(ctx)
}
@@ -319,6 +368,9 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
}
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
if b.hasModel {
return b
}
b.query = b.query.Table(table)
return b
}
@@ -343,7 +395,12 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
return b
}
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
@@ -424,7 +481,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return b
}
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
return &BunResult{result: result}, err
}
@@ -449,7 +511,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
return b
}
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
return &BunResult{result: result}, err
}

View File

@@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
@@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
return &GormResult{result: result}, result.Error
}
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
}
@@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
return g.db.WithContext(ctx).Rollback().Error
}
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx}
return fn(adapter)
@@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
return g
}
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
return g.db.WithContext(ctx).Find(dest).Error
}
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
}
}()
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
}
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
var count int64
err := g.db.WithContext(ctx).Count(&count).Error
return int(count), err
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
return int(count64), err
}
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false
}
}()
var count int64
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
return count > 0, err
}
@@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
return g
}
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
@@ -401,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return g
}
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
return &GormResult{result: result}, result.Error
}
@@ -428,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
return g
}
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
return &GormResult{result: result}, result.Error
}

136
pkg/common/sql_helpers.go Normal file
View File

@@ -0,0 +1,136 @@
package common
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
// the relation prefix (alias). If not present, it attempts to add it to column references.
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
// Check if the relation name is already present in the WHERE clause
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
if strings.Contains(lowerWhere, lowerRelation+".") ||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
// Relation prefix is already present
return where, nil
}
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
// we can't safely auto-fix it - require explicit prefix
if strings.Contains(lowerWhere, " or ") ||
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Check if this is a SQL expression/literal that shouldn't be prefixed
lowerCond := strings.ToLower(strings.TrimSpace(cond))
if IsSQLExpression(lowerCond) {
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
fixedConditions = append(fixedConditions, cond)
continue
}
// Extract the column name (first identifier before operator)
columnName := ExtractColumnName(cond)
if columnName == "" {
// Can't identify column name, require explicit prefix
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
}
// Add relation prefix to the column name only
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}

View File

@@ -92,9 +92,27 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
// Handles PostgreSQL JSON operators (-> and ->>)
func (v *ColumnValidator) ValidateColumn(column string) error {
// Allow empty columns
if column == "" {
@@ -106,8 +124,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
return nil
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
}

View File

@@ -0,0 +1,124 @@
package common
import (
"testing"
)
func TestExtractSourceColumn(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "simple column name",
input: "columna",
expected: "columna",
},
{
name: "column with ->> operator",
input: "columna->>'val'",
expected: "columna",
},
{
name: "column with -> operator",
input: "columna->'key'",
expected: "columna",
},
{
name: "column with table prefix and ->> operator",
input: "table.columna->>'val'",
expected: "table.columna",
},
{
name: "column with table prefix and -> operator",
input: "table.columna->'key'",
expected: "table.columna",
},
{
name: "complex JSON path with ->>",
input: "data->>'nested'->>'value'",
expected: "data",
},
{
name: "column with spaces before operator",
input: "columna ->>'val'",
expected: "columna",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}
}
func TestValidateColumnWithJSONOperators(t *testing.T) {
// Create a test model
type TestModel struct {
ID int `json:"id"`
Name string `json:"name"`
Data string `json:"data"` // JSON column
Metadata string `json:"metadata"`
}
validator := NewColumnValidator(TestModel{})
testCases := []struct {
name string
column string
shouldErr bool
}{
{
name: "simple valid column",
column: "name",
shouldErr: false,
},
{
name: "valid column with ->> operator",
column: "data->>'field'",
shouldErr: false,
},
{
name: "valid column with -> operator",
column: "metadata->'key'",
shouldErr: false,
},
{
name: "invalid column",
column: "invalid_column",
shouldErr: true,
},
{
name: "invalid column with ->> operator",
column: "invalid_column->>'field'",
shouldErr: true,
},
{
name: "cql prefixed column (always valid)",
column: "cql_computed",
shouldErr: false,
},
{
name: "empty column",
column: "",
shouldErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validator.ValidateColumn(tc.column)
if tc.shouldErr && err == nil {
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
}
if !tc.shouldErr && err != nil {
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
}
})
}
}

View File

@@ -103,3 +103,18 @@ func CatchPanicCallback(location string, cb func(err any)) {
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
}
// HandlePanic logs a panic and returns it as an error
// This should be called with the result of recover() from a deferred function
// Example usage:
//
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
return fmt.Errorf("panic in %s: %v", methodName, r)
}

View File

@@ -1132,10 +1132,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName
// For now, we'll preload without conditions
// TODO: Implement column selection and filtering for preloads
// This requires a more sophisticated approach with callbacks or query builders
// Apply preloading
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
}
preload.Where = fixedWhere
}
logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {

View File

@@ -13,6 +13,7 @@ const (
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
contextKeyOptions contextKey = "options"
)
// WithSchema adds schema to context
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
// WithOptions adds request options to context
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
return context.WithValue(ctx, contextKeyOptions, options)
}
// GetOptions retrieves request options from context
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
if v := ctx.Value(contextKeyOptions); v != nil {
if opts, ok := v.(ExtendedRequestOptions); ok {
return &opts
}
}
return nil
}
// WithRequestData adds all request-scoped data to context at once
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
ctx = WithOptions(ctx, options)
return ctx
}

View File

@@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
entity := params["entity"]
id := params["id"]
// Parse options from headers (now returns ExtendedRequestOptions)
options := h.parseOptionsFromHeaders(r)
// Determine operation based on HTTP method
method := r.Method()
@@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
// Parse options from headers - this now includes relation name resolution
options := h.parseOptionsFromHeaders(r, model)
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
switch method {
case "GET":
if id != "" {
@@ -260,6 +260,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName)
}
// If we have computed columns/expressions but options.Columns is empty,
// populate it with all model columns first since computed columns are additions
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetModelColumns(model)
}
// Apply ComputedQL fields if any
if len(options.ComputedQL) > 0 {
for colName, colExpr := range options.ComputedQL {
@@ -340,6 +347,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
for idx := range options.Preload {
preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation)
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
return
}
preload.Where = fixedWhere
}
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)
@@ -663,7 +683,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
}
// Create insert query
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
query := tx.NewInsert().Model(modelValue)
// Only set Table() if the model doesn't provide a table name via TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
query = query.Returning("*")
// Execute BeforeScan hooks - pass query chain so hooks can modify it
itemHookCtx := &HookContext{
@@ -1640,13 +1667,9 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
data = h.normalizeResultArray(data)
}
response := common.Response{
Success: true,
Data: data,
Metadata: metadata,
}
// Return data as-is without wrapping in common.Response
w.WriteHeader(http.StatusOK)
if err := w.WriteJSON(response); err != nil {
if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err)
}
}

View File

@@ -2,6 +2,7 @@ package restheadspec
import (
"encoding/base64"
"encoding/json"
"fmt"
"reflect"
"strconv"
@@ -42,6 +43,9 @@ type ExtendedRequestOptions struct {
// Transaction
AtomicTransaction bool
// X-Files configuration - comprehensive query options as a single JSON object
XFiles *XFiles
}
// ExpandOption represents a relation expansion configuration
@@ -95,7 +99,8 @@ func DecodeParam(pStr string) (string, error) {
}
// parseOptionsFromHeaders parses all request options from HTTP headers
func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions {
// If model is provided, it will resolve table names to field names in preload/expand options
func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions {
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Filters: make([]common.FilterOption, 0),
@@ -214,9 +219,18 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
// Transaction Control
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
// X-Files - comprehensive JSON configuration
case strings.HasPrefix(normalizedKey, "x-files"):
h.parseXFiles(&options, decodedValue)
}
}
// Resolve relation names (convert table names to field names) if model is provided
if model != nil {
h.resolveRelationNamesInOptions(&options, model)
}
return options
}
@@ -480,12 +494,472 @@ func (h *Handler) parseCommaSeparated(value string) []string {
return result
}
// parseXFiles parses x-files header containing comprehensive JSON configuration
// and populates ExtendedRequestOptions fields from it
func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
if value == "" {
return
}
var xfiles XFiles
if err := json.Unmarshal([]byte(value), &xfiles); err != nil {
logger.Warn("Failed to parse x-files header: %v", err)
return
}
logger.Debug("Parsed x-files configuration for table: %s", xfiles.TableName)
// Store the original XFiles for reference
options.XFiles = &xfiles
// Map XFiles fields to ExtendedRequestOptions
// Column selection
if len(xfiles.Columns) > 0 {
options.Columns = append(options.Columns, xfiles.Columns...)
logger.Debug("X-Files: Added columns: %v", xfiles.Columns)
}
// Omit columns
if len(xfiles.OmitColumns) > 0 {
options.OmitColumns = append(options.OmitColumns, xfiles.OmitColumns...)
logger.Debug("X-Files: Added omit columns: %v", xfiles.OmitColumns)
}
// Computed columns (CQL) -> ComputedQL
if len(xfiles.CQLColumns) > 0 {
if options.ComputedQL == nil {
options.ComputedQL = make(map[string]string)
}
for i, cqlExpr := range xfiles.CQLColumns {
colName := fmt.Sprintf("cql%d", i+1)
options.ComputedQL[colName] = cqlExpr
logger.Debug("X-Files: Added computed column %s: %s", colName, cqlExpr)
}
}
// Sorting
if len(xfiles.Sort) > 0 {
for _, sortField := range xfiles.Sort {
direction := "ASC"
colName := sortField
// Handle direction prefixes
if strings.HasPrefix(sortField, "-") {
direction = "DESC"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
colName = strings.TrimPrefix(sortField, "+")
}
// Handle DESC suffix
if strings.HasSuffix(strings.ToLower(colName), " desc") {
direction = "DESC"
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
}
options.Sort = append(options.Sort, common.SortOption{
Column: strings.TrimSpace(colName),
Direction: direction,
})
}
logger.Debug("X-Files: Added %d sort options", len(xfiles.Sort))
}
// Filter fields
if len(xfiles.FilterFields) > 0 {
for _, filterField := range xfiles.FilterFields {
options.Filters = append(options.Filters, common.FilterOption{
Column: filterField.Field,
Operator: filterField.Operator,
Value: filterField.Value,
LogicOperator: "AND", // Default to AND
})
}
logger.Debug("X-Files: Added %d filter fields", len(xfiles.FilterFields))
}
// SQL AND conditions -> CustomSQLWhere
if len(xfiles.SqlAnd) > 0 {
if options.CustomSQLWhere != "" {
options.CustomSQLWhere += " AND "
}
options.CustomSQLWhere += "(" + strings.Join(xfiles.SqlAnd, " AND ") + ")"
logger.Debug("X-Files: Added SQL AND conditions")
}
// SQL OR conditions -> CustomSQLOr
if len(xfiles.SqlOr) > 0 {
if options.CustomSQLOr != "" {
options.CustomSQLOr += " OR "
}
options.CustomSQLOr += "(" + strings.Join(xfiles.SqlOr, " OR ") + ")"
logger.Debug("X-Files: Added SQL OR conditions")
}
// Pagination - Limit
if limitStr := xfiles.Limit.String(); limitStr != "" && limitStr != "0" {
if limitVal, err := xfiles.Limit.Int64(); err == nil && limitVal > 0 {
limit := int(limitVal)
options.Limit = &limit
logger.Debug("X-Files: Set limit: %d", limit)
}
}
// Pagination - Offset
if offsetStr := xfiles.Offset.String(); offsetStr != "" && offsetStr != "0" {
if offsetVal, err := xfiles.Offset.Int64(); err == nil && offsetVal > 0 {
offset := int(offsetVal)
options.Offset = &offset
logger.Debug("X-Files: Set offset: %d", offset)
}
}
// Cursor pagination
if xfiles.CursorForward != "" {
options.CursorForward = xfiles.CursorForward
logger.Debug("X-Files: Set cursor forward")
}
if xfiles.CursorBackward != "" {
options.CursorBackward = xfiles.CursorBackward
logger.Debug("X-Files: Set cursor backward")
}
// Flags
if xfiles.Skipcount {
options.SkipCount = true
logger.Debug("X-Files: Set skip count")
}
// Process ParentTables and ChildTables recursively
h.processXFilesRelations(&xfiles, options, "")
}
// processXFilesRelations processes ParentTables and ChildTables from XFiles
// and adds them as Preload options recursively
func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedRequestOptions, basePath string) {
if xfiles == nil {
return
}
// Process ParentTables
if len(xfiles.ParentTables) > 0 {
logger.Debug("X-Files: Processing %d parent tables", len(xfiles.ParentTables))
for _, parentTable := range xfiles.ParentTables {
h.addXFilesPreload(parentTable, options, basePath)
}
}
// Process ChildTables
if len(xfiles.ChildTables) > 0 {
logger.Debug("X-Files: Processing %d child tables", len(xfiles.ChildTables))
for _, childTable := range xfiles.ChildTables {
h.addXFilesPreload(childTable, options, basePath)
}
}
}
// resolveRelationNamesInOptions resolves all table names to field names in preload options
// This is called internally by parseOptionsFromHeaders when a model is provided
func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) {
if options == nil || model == nil {
return
}
// Resolve relation names in all preload options
for i := range options.Preload {
preload := &options.Preload[i]
// Split the relation path (e.g., "parent.child.grandchild")
parts := strings.Split(preload.Relation, ".")
resolvedParts := make([]string, 0, len(parts))
// Resolve each part of the path
currentModel := model
for _, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part)
resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level
// This allows nested resolution
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
currentModel = nextModel
}
}
// Update the relation path with resolved names
resolvedPath := strings.Join(resolvedParts, ".")
if resolvedPath != preload.Relation {
logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath)
preload.Relation = resolvedPath
}
}
// Resolve relation names in expand options
for i := range options.Expand {
expand := &options.Expand[i]
resolved := h.resolveRelationName(model, expand.Relation)
if resolved != expand.Relation {
logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved)
expand.Relation = resolved
}
}
}
// getRelationModel gets the model type for a relation field
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field
field, found := modelType.FieldByName(fieldName)
if !found {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}
// resolveRelationName resolves a relation name or table name to the actual field name in the model
// If the input is already a field name, it returns it as-is
// If the input is a table name, it looks up the corresponding relation field
func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string {
if model == nil || nameOrTable == "" {
return nameOrTable
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Check again after dereferencing
if modelType == nil {
return nameOrTable
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return nameOrTable
}
// First, check if the input matches a field name directly
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if field.Name == nameOrTable {
// It's already a field name
logger.Debug("Input '%s' is a field name", nameOrTable)
return nameOrTable
}
}
// If not found as a field name, try to look it up as a table name
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
// Check if it's a slice or pointer to a struct
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil {
// Dereference pointer if the slice contains pointers
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
// Check if it's a struct type
if targetType.Kind() == reflect.Struct {
// Get the type name and normalize it
typeName := targetType.Name()
// Extract the table name from type name
// Patterns: ModelCoreMastertaskitem -> mastertaskitem
// ModelMastertaskitem -> mastertaskitem
normalizedTypeName := strings.ToLower(typeName)
// Remove common prefixes like "model", "modelcore", etc.
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
// Compare normalized names
if normalizedTypeName == normalizedInput {
logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName)
return field.Name
}
}
}
}
// If no match found, return the original input
logger.Debug("No field found for '%s', using as-is", nameOrTable)
return nameOrTable
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
if xfile == nil || xfile.TableName == "" {
return
}
// Store the table name as-is for now - it will be resolved to field name later
// when we have the model instance available
relationPath := xfile.TableName
if basePath != "" {
relationPath = basePath + "." + xfile.TableName
}
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
// Create PreloadOption from XFiles configuration
preloadOpt := common.PreloadOption{
Relation: relationPath,
Columns: xfile.Columns,
OmitColumns: xfile.OmitColumns,
}
// Add sorting if specified
if len(xfile.Sort) > 0 {
preloadOpt.Sort = make([]common.SortOption, 0, len(xfile.Sort))
for _, sortField := range xfile.Sort {
direction := "ASC"
colName := sortField
// Handle direction prefixes
if strings.HasPrefix(sortField, "-") {
direction = "DESC"
colName = strings.TrimPrefix(sortField, "-")
} else if strings.HasPrefix(sortField, "+") {
colName = strings.TrimPrefix(sortField, "+")
}
preloadOpt.Sort = append(preloadOpt.Sort, common.SortOption{
Column: strings.TrimSpace(colName),
Direction: direction,
})
}
}
// Add filters if specified
if len(xfile.FilterFields) > 0 {
preloadOpt.Filters = make([]common.FilterOption, 0, len(xfile.FilterFields))
for _, filterField := range xfile.FilterFields {
preloadOpt.Filters = append(preloadOpt.Filters, common.FilterOption{
Column: filterField.Field,
Operator: filterField.Operator,
Value: filterField.Value,
LogicOperator: "AND",
})
}
}
// Add WHERE clause if SQL conditions specified
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
whereConditions = append(whereConditions, xfile.SqlAnd...)
}
if len(whereConditions) > 0 {
preloadOpt.Where = strings.Join(whereConditions, " AND ")
}
// Add limit if specified
if limitStr := xfile.Limit.String(); limitStr != "" && limitStr != "0" {
if limitVal, err := xfile.Limit.Int64(); err == nil && limitVal > 0 {
limit := int(limitVal)
preloadOpt.Limit = &limit
}
}
// Add the preload option
options.Preload = append(options.Preload, preloadOpt)
// Recursively process nested ParentTables and ChildTables
if xfile.Recursive {
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
h.processXFilesRelations(xfile, options, relationPath)
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
h.processXFilesRelations(xfile, options, relationPath)
}
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := extractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
@@ -506,19 +980,19 @@ func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) refl
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == colName {
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, colName) {
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := toSnakeCase(field.Name)
if snakeCaseName == colName {
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}

431
pkg/restheadspec/xfiles.go Normal file
View File

@@ -0,0 +1,431 @@
package restheadspec
import (
"encoding/json"
"reflect"
)
type XFiles struct {
TableName string `json:"tablename"`
Schema string `json:"schema"`
PrimaryKey string `json:"primarykey"`
ForeignKey string `json:"foreignkey"`
RelatedKey string `json:"relatedkey"`
Sort []string `json:"sort"`
Prefix string `json:"prefix"`
Editable bool `json:"editable"`
Recursive bool `json:"recursive"`
Expand bool `json:"expand"`
Rownumber bool `json:"rownumber"`
Skipcount bool `json:"skipcount"`
Offset json.Number `json:"offset"`
Limit json.Number `json:"limit"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
CQLColumns []string `json:"cql_columns"`
SqlJoins []string `json:"sql_joins"`
SqlOr []string `json:"sql_or"`
SqlAnd []string `json:"sql_and"`
ParentTables []*XFiles `json:"parenttables"`
ChildTables []*XFiles `json:"childtables"`
ModelType reflect.Type `json:"-"`
ParentEntity *XFiles `json:"-"`
Level uint `json:"-"`
Errors []error `json:"-"`
FilterFields []struct {
Field string `json:"field"`
Value string `json:"value"`
Operator string `json:"operator"`
} `json:"filter_fields"`
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
}
// func (m *XFiles) SetParent() {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity != nil {
// continue
// }
// child.ParentEntity = m
// child.Level = m.Level + 1000
// child.SetParent()
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity != nil {
// continue
// }
// pt.ParentEntity = m
// pt.Level = m.Level + 1
// pt.SetParent()
// }
// }
// }
// func (m *XFiles) GetParentRelations() []reflection.GormRelationType {
// if m.ParentEntity == nil {
// return nil
// }
// foundRelations := make(GormRelationTypeList, 0)
// rels := reflection.GetValidModelRelationTypes(m.ParentEntity.ModelType, false)
// if m.ParentEntity.ModelType == nil {
// return nil
// }
// for _, rel := range rels {
// // if len(foundRelations) > 0 {
// // break
// // }
// if rel.FieldName != "" && rel.AssociationTable.Name() == m.ModelType.Name() {
// if rel.AssociationKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.AssociationKey, m.RelatedKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.AssociationKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.AssociationKey, m.ForeignKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.ForeignKey, m.ForeignKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.ForeignKey, m.RelatedKey) {
// foundRelations = append(foundRelations, rel)
// } else if rel.ForeignKey != "" && m.ForeignKey == "" && m.RelatedKey == "" {
// foundRelations = append(foundRelations, rel)
// }
// }
// //idName := fmt.Sprintf("%s_to_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
// }
// sort.Sort(foundRelations)
// finalList := make(GormRelationTypeList, 0)
// dups := make(map[string]bool)
// for _, rel := range foundRelations {
// idName := fmt.Sprintf("%s_to_%s_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.FieldName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
// if dups[idName] {
// continue
// }
// finalList = append(finalList, rel)
// dups[idName] = true
// }
// //fmt.Printf("GetParentRelations %s: %+v %d=%d\n", m.TableName, dups, len(finalList), len(foundRelations))
// return finalList
// }
// func (m *XFiles) GetUpdatableTableNames() []string {
// foundTables := make([]string, 0)
// if m.Editable {
// foundTables = append(foundTables, m.TableName)
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// list := pt.GetUpdatableTableNames()
// if list != nil {
// foundTables = append(foundTables, list...)
// }
// }
// }
// if m.ChildTables != nil {
// for _, ct := range m.ChildTables {
// list := ct.GetUpdatableTableNames()
// if list != nil {
// foundTables = append(foundTables, list...)
// }
// }
// }
// return foundTables
// }
// func (m *XFiles) preload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
// path := pPath
// _, colval := JSONSyntaxToSQLIn(path, m.ModelType, "preload")
// if colval != "" {
// path = colval
// }
// if path == "" {
// return db, fmt.Errorf("invalid preload path %s", path)
// }
// sortList := ""
// if m.Sort != nil {
// for _, sort := range m.Sort {
// descSort := false
// if strings.HasPrefix(sort, "-") || strings.Contains(strings.ToLower(sort), " desc") {
// descSort = true
// }
// sort = strings.TrimPrefix(strings.TrimPrefix(sort, "+"), "-")
// sort = strings.ReplaceAll(strings.ReplaceAll(sort, " desc", ""), " asc", "")
// if descSort {
// sort = sort + " desc"
// }
// sortList = sort
// }
// }
// SrcColumns := reflection.GetModelSQLColumns(m.ModelType)
// Columns := make([]string, 0)
// for _, s := range SrcColumns {
// for _, v := range m.Columns {
// if strings.EqualFold(v, s) {
// Columns = append(Columns, v)
// break
// }
// }
// }
// if len(Columns) == 0 {
// Columns = SrcColumns
// }
// chain := db
// // //Do expand where we can
// // if m.Expand {
// // ops := func(subchain *gorm.DB) *gorm.DB {
// // subchain = subchain.Select(strings.Join(m.Columns, ","))
// // if m.Filter != "" {
// // subchain = subchain.Where(m.Filter)
// // }
// // return subchain
// // }
// // chain = chain.Joins(path, ops(chain))
// // }
// //fmt.Printf("Preloading %s: %s lvl:%d \n", m.TableName, path, m.Level)
// //Do preload
// chain = chain.Preload(path, func(db *gorm.DB) *gorm.DB {
// subchain := db
// if sortList != "" {
// subchain = subchain.Order(sortList)
// }
// for _, sql := range m.SqlAnd {
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
// if fnType == 0 {
// colval = ValidSQL(colval, "select")
// }
// subchain = subchain.Where(colval)
// }
// for _, sql := range m.SqlOr {
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
// if fnType == 0 {
// colval = ValidSQL(colval, "select")
// }
// subchain = subchain.Or(colval)
// }
// limitval, err := m.Limit.Int64()
// if err == nil && limitval > 0 {
// subchain = subchain.Limit(int(limitval))
// }
// for _, j := range m.SqlJoins {
// subchain = subchain.Joins(ValidSQL(j, "select"))
// }
// offsetval, err := m.Offset.Int64()
// if err == nil && offsetval > 0 {
// subchain = subchain.Offset(int(offsetval))
// }
// cols := make([]string, 0)
// for _, col := range Columns {
// canAdd := true
// for _, omit := range m.OmitColumns {
// if col == omit {
// canAdd = false
// break
// }
// }
// if canAdd {
// cols = append(cols, col)
// }
// }
// for i, col := range m.CQLColumns {
// cols = append(cols, fmt.Sprintf("(%s) as cql%d", col, i+1))
// }
// if len(cols) > 0 {
// colStr := strings.Join(cols, ",")
// subchain = subchain.Select(colStr)
// }
// if m.Recursive && pCnt < 5 {
// paths := strings.Split(path, ".")
// p := paths[0]
// if len(paths) > 1 {
// p = strings.Join(paths[1:], ".")
// }
// for i := uint(0); i < 3; i++ {
// inlineStr := strings.Repeat(p+".", int(i+1))
// inlineStr = strings.TrimRight(inlineStr, ".")
// fmt.Printf("Preloading Recursive (%d) %s: %s lvl:%d \n", i, m.TableName, inlineStr, m.Level)
// subchain, err = m.preload(subchain, inlineStr, pCnt+i)
// if err != nil {
// cfg.LogError("Preload (%s,%d) error: %v", m.TableName, pCnt, err)
// } else {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// subchain, _ = child.ChainPreload(subchain, inlineStr, pCnt+i)
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// subchain, _ = pt.ChainPreload(subchain, inlineStr, pCnt+i)
// }
// }
// }
// }
// }
// return subchain
// })
// return chain, nil
// }
// func (m *XFiles) ChainPreload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
// var err error
// chain := db
// relations := m.GetParentRelations()
// if pCnt > 10000 {
// cfg.LogError("Preload Max size (%s,%s): %v", m.TableName, pPath, err)
// return chain, nil
// }
// hasPreloadError := false
// for _, rel := range relations {
// path := rel.FieldName
// if pPath != "" {
// path = fmt.Sprintf("%s.%s", pPath, rel.FieldName)
// }
// chain, err = m.preload(chain, path, pCnt)
// if err != nil {
// cfg.LogError("Preload Error (%s,%s): %v", m.TableName, path, err)
// hasPreloadError = true
// //return chain, err
// }
// //fmt.Printf("Preloading Rel %v: %s @ %s lvl:%d \n", m.Recursive, path, m.TableName, m.Level)
// if !hasPreloadError && m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// chain, err = child.ChainPreload(chain, path, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// if !hasPreloadError && m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// chain, err = pt.ChainPreload(chain, path, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// }
// if len(relations) == 0 {
// if m.ChildTables != nil {
// for _, child := range m.ChildTables {
// if child.ParentEntity == nil {
// continue
// }
// chain, err = child.ChainPreload(chain, pPath, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// if m.ParentTables != nil {
// for _, pt := range m.ParentTables {
// if pt.ParentEntity == nil {
// continue
// }
// chain, err = pt.ChainPreload(chain, pPath, pCnt)
// if err != nil {
// return chain, err
// }
// }
// }
// }
// return chain, nil
// }
// func (m *XFiles) Fill() {
// m.ModelType = models.GetModelType(m.Schema, m.TableName)
// if m.ModelType == nil {
// m.Errors = append(m.Errors, fmt.Errorf("ModelType not found for %s", m.TableName))
// }
// if m.Prefix == "" {
// m.Prefix = reflection.GetTablePrefixFromType(m.ModelType)
// }
// if m.PrimaryKey == "" {
// m.PrimaryKey = reflection.GetPKNameFromType(m.ModelType)
// }
// if m.Schema == "" {
// m.Schema = reflection.GetSchemaNameFromType(m.ModelType)
// }
// for _, t := range m.ParentTables {
// t.Fill()
// }
// for _, t := range m.ChildTables {
// t.Fill()
// }
// }
// type GormRelationTypeList []reflection.GormRelationType
// func (s GormRelationTypeList) Len() int { return len(s) }
// func (s GormRelationTypeList) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// func (s GormRelationTypeList) Less(i, j int) bool {
// if strings.HasPrefix(strings.ToLower(s[j].FieldName),
// strings.ToLower(fmt.Sprintf("%s_%s_%s", s[i].AssociationSchema, s[i].AssociationTable, s[i].AssociationKey))) {
// return true
// }
// return s[i].FieldName < s[j].FieldName
// }

View File

@@ -0,0 +1,213 @@
# X-Files Header Usage
The `x-files` header allows you to configure complex query options using a single JSON object. The XFiles configuration is parsed and populates the `ExtendedRequestOptions` fields, which means it integrates seamlessly with the existing query building system.
## Architecture
When an `x-files` header is received:
1. It's parsed into an `XFiles` struct
2. The `XFiles` fields populate the `ExtendedRequestOptions` (columns, filters, sort, preload, etc.)
3. The normal query building process applies these options to the SQL query
4. This allows x-files to work alongside individual headers if needed
## Basic Example
```http
GET /public/users
X-Files: {"tablename":"users","columns":["id","name","email"],"limit":"10","offset":"0"}
```
## Complete Example
```http
GET /public/users
X-Files: {
"tablename": "users",
"schema": "public",
"columns": ["id", "name", "email", "created_at"],
"omit_columns": [],
"sort": ["-created_at", "name"],
"limit": "50",
"offset": "0",
"filter_fields": [
{
"field": "status",
"operator": "eq",
"value": "active"
},
{
"field": "age",
"operator": "gt",
"value": "18"
}
],
"sql_and": ["deleted_at IS NULL"],
"sql_or": [],
"cql_columns": ["UPPER(name)"],
"skipcount": false,
"distinct": false
}
```
## Supported Filter Operators
- `eq` - equals
- `neq` - not equals
- `gt` - greater than
- `gte` - greater than or equals
- `lt` - less than
- `lte` - less than or equals
- `like` - SQL LIKE
- `ilike` - case-insensitive LIKE
- `in` - IN clause
- `between` - between (exclusive)
- `between_inclusive` - between (inclusive)
- `is_null` - is NULL
- `is_not_null` - is NOT NULL
## Sorting
Sort fields can be prefixed with:
- `+` for ascending (default)
- `-` for descending
Examples:
- `"sort": ["name"]` - ascending by name
- `"sort": ["-created_at"]` - descending by created_at
- `"sort": ["-created_at", "name"]` - multiple sorts
## Computed Columns (CQL)
Use `cql_columns` to add computed SQL expressions:
```json
{
"cql_columns": [
"UPPER(name)",
"CONCAT(first_name, ' ', last_name)"
]
}
```
These will be available as `cql1`, `cql2`, etc. in the response.
## Cursor Pagination
```json
{
"cursor_forward": "eyJpZCI6MTAwfQ==",
"cursor_backward": ""
}
```
## Base64 Encoding
For complex JSON, you can base64-encode the value and prefix it with `ZIP_` or `__`:
```http
GET /public/users
X-Files: ZIP_eyJ0YWJsZW5hbWUiOiJ1c2VycyIsImxpbWl0IjoiMTAifQ==
```
## XFiles Struct Reference
```go
type XFiles struct {
TableName string `json:"tablename"`
Schema string `json:"schema"`
PrimaryKey string `json:"primarykey"`
ForeignKey string `json:"foreignkey"`
RelatedKey string `json:"relatedkey"`
Sort []string `json:"sort"`
Prefix string `json:"prefix"`
Editable bool `json:"editable"`
Recursive bool `json:"recursive"`
Expand bool `json:"expand"`
Rownumber bool `json:"rownumber"`
Skipcount bool `json:"skipcount"`
Offset json.Number `json:"offset"`
Limit json.Number `json:"limit"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
CQLColumns []string `json:"cql_columns"`
SqlJoins []string `json:"sql_joins"`
SqlOr []string `json:"sql_or"`
SqlAnd []string `json:"sql_and"`
FilterFields []struct {
Field string `json:"field"`
Value string `json:"value"`
Operator string `json:"operator"`
} `json:"filter_fields"`
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
}
```
## Recursive Preloading with ParentTables and ChildTables
XFiles now supports recursive preloading of related entities:
```json
{
"tablename": "users",
"columns": ["id", "name"],
"limit": "10",
"parenttables": [
{
"tablename": "Company",
"columns": ["id", "name", "industry"],
"sort": ["-created_at"]
}
],
"childtables": [
{
"tablename": "Orders",
"columns": ["id", "total", "status"],
"limit": "5",
"sort": ["-order_date"],
"filter_fields": [
{"field": "status", "operator": "eq", "value": "completed"}
],
"childtables": [
{
"tablename": "OrderItems",
"columns": ["id", "product_name", "quantity"],
"recursive": true
}
]
}
]
}
```
### How Recursive Preloading Works
- **ParentTables**: Preloads parent relationships (e.g., User -> Company)
- **ChildTables**: Preloads child relationships (e.g., User -> Orders -> OrderItems)
- **Recursive**: When `true`, continues preloading the same relation recursively
- Each nested table can have its own:
- Column selection (`columns`, `omit_columns`)
- Filtering (`filter_fields`, `sql_and`)
- Sorting (`sort`)
- Pagination (`limit`)
- Further nesting (`parenttables`, `childtables`)
### Relation Path Building
Relations are built as dot-separated paths:
- `Company` (direct parent)
- `Orders` (direct child)
- `Orders.OrderItems` (nested child)
- `Orders.OrderItems.Product` (deeply nested)
## Notes
- Individual headers (like `x-select-fields`, `x-sort`, etc.) can still be used alongside `x-files`
- X-Files populates `ExtendedRequestOptions` which is then processed by the normal query building logic
- ParentTables and ChildTables are converted to `PreloadOption` entries with full support for:
- Column selection
- Filtering
- Sorting
- Limit
- Recursive nesting
- The relation name in ParentTables/ChildTables should match the GORM/Bun relation field name on the model

View File

@@ -372,7 +372,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Create department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Create department should succeed")
} else {
// Unwrapped format - verify we got the created data back
assert.NotEmpty(t, result, "Create department should return data")
assert.Equal(t, deptID, result["id"], "Created department should have correct ID")
}
logger.Info("Department created successfully: %s", deptID)
})
@@ -393,7 +400,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Create employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Create employee should succeed")
} else {
// Unwrapped format - verify we got the created data back
assert.NotEmpty(t, result, "Create employee should return data")
assert.Equal(t, empID, result["id"], "Created employee should have correct ID")
}
logger.Info("Employee created successfully: %s", empID)
})
@@ -540,7 +554,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Update department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Update department should succeed")
} else {
// Unwrapped format - verify we got the updated data back
assert.NotEmpty(t, result, "Update department should return data")
}
logger.Info("Department updated successfully: %s", deptID)
// Verify update by reading the department again
@@ -558,7 +578,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Update employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Update employee should succeed")
} else {
// Unwrapped format - verify we got the updated data back
assert.NotEmpty(t, result, "Update employee should return data")
}
logger.Info("Employee updated successfully: %s", empID)
})
@@ -569,7 +595,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Delete employee should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Delete employee should succeed")
} else {
// Unwrapped format - verify we got a response (typically {"deleted": count})
assert.NotEmpty(t, result, "Delete employee should return data")
}
logger.Info("Employee deleted successfully: %s", empID)
// Verify deletion - just log that delete succeeded
@@ -582,7 +614,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
var result map[string]interface{}
json.NewDecoder(resp.Body).Decode(&result)
assert.True(t, result["success"].(bool), "Delete department should succeed")
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
if success, ok := result["success"]; ok && success != nil {
assert.True(t, success.(bool), "Delete department should succeed")
} else {
// Unwrapped format - verify we got a response (typically {"deleted": count})
assert.NotEmpty(t, result, "Delete department should return data")
}
logger.Info("Department deleted successfully: %s", deptID)
})