mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 07:44:25 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
02c9b96b0c | ||
|
|
9a3564f05f | ||
|
|
a931b8cdd2 | ||
|
|
7e76977dcc | ||
|
|
7853a3f56a | ||
|
|
c2e0c36c79 | ||
|
|
59bd709460 | ||
|
|
05962035b6 | ||
|
|
1cd04b7083 | ||
|
|
0d4909054c | ||
|
|
745564f2e7 | ||
|
|
311e50bfdd | ||
|
|
c95bc9e633 | ||
|
|
07b09e2025 | ||
|
|
3d5334002d | ||
|
|
640582d508 |
@@ -86,7 +86,6 @@
|
||||
"emptyFallthrough",
|
||||
"equalFold",
|
||||
"flagName",
|
||||
"ifElseChain",
|
||||
"indexAlloc",
|
||||
"initClause",
|
||||
"methodExprCall",
|
||||
@@ -106,6 +105,9 @@
|
||||
"unnecessaryBlock",
|
||||
"weakCond",
|
||||
"yodaStyleExpr"
|
||||
],
|
||||
"disabled-checks": [
|
||||
"ifElseChain"
|
||||
]
|
||||
},
|
||||
"revive": {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
@@ -43,12 +44,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
|
||||
@@ -73,7 +84,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx}
|
||||
@@ -219,6 +235,14 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
if len(apply) == 0 {
|
||||
return sq
|
||||
}
|
||||
@@ -276,15 +300,38 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
if dest == nil {
|
||||
return fmt.Errorf("destination cannot be nil")
|
||||
}
|
||||
return b.query.Scan(ctx, dest)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if b.query.GetModel() == nil {
|
||||
return fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
return b.query.Scan(ctx)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
@@ -293,15 +340,20 @@ func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
// This is needed when only Table() is set without a model
|
||||
var count int
|
||||
err := b.db.NewSelect().
|
||||
err = b.db.NewSelect().
|
||||
TableExpr("(?) AS subquery", b.query).
|
||||
ColumnExpr("COUNT(*)").
|
||||
Scan(ctx, &count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
return b.query.Exists(ctx)
|
||||
}
|
||||
|
||||
@@ -319,6 +371,9 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||
if b.hasModel {
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Table(table)
|
||||
return b
|
||||
}
|
||||
@@ -343,8 +398,13 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
if b.values != nil && len(b.values) > 0 {
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
// Bun can insert map[string]interface{} directly
|
||||
@@ -424,7 +484,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
@@ -449,7 +514,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
@@ -38,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
|
||||
@@ -63,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
||||
return g.db.WithContext(ctx).Rollback().Error
|
||||
}
|
||||
|
||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx}
|
||||
return fn(adapter)
|
||||
@@ -255,26 +271,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
var count int64
|
||||
err := g.db.WithContext(ctx).Count(&count).Error
|
||||
return int(count), err
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
var count int64
|
||||
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
@@ -314,7 +352,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
switch {
|
||||
case g.model != nil:
|
||||
@@ -401,7 +444,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
@@ -428,7 +476,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
@@ -121,6 +121,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
|
||||
return b.req.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (b *BunRouterRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range b.req.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range b.req.Header {
|
||||
|
||||
@@ -117,6 +117,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
|
||||
return h.req.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (h *HTTPRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range h.req.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (h *HTTPRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range h.req.Header {
|
||||
|
||||
@@ -116,6 +116,7 @@ type Request interface {
|
||||
Body() ([]byte, error)
|
||||
PathParam(key string) string
|
||||
QueryParam(key string) string
|
||||
AllQueryParams() map[string]string // Get all query parameters as a map
|
||||
}
|
||||
|
||||
// ResponseWriter interface abstracts HTTP response
|
||||
|
||||
340
pkg/common/sql_helpers.go
Normal file
340
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// Check if the relation name is already present in the WHERE clause
|
||||
lowerWhere := strings.ToLower(where)
|
||||
lowerRelation := strings.ToLower(relationName)
|
||||
|
||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||
// Relation prefix is already present
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||
// we can't safely auto-fix it - require explicit prefix
|
||||
if strings.Contains(lowerWhere, " or ") ||
|
||||
strings.Contains(where, "(") ||
|
||||
strings.Contains(where, ")") {
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||
}
|
||||
|
||||
// Try to add the relation prefix to simple column references
|
||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||
// Split by AND to handle multiple conditions (case-insensitive)
|
||||
originalConditions := strings.Split(where, " AND ")
|
||||
|
||||
// If uppercase split didn't work, try lowercase
|
||||
if len(originalConditions) == 1 {
|
||||
originalConditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
fixedConditions := make([]string, 0, len(originalConditions))
|
||||
|
||||
for _, cond := range originalConditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this condition already has a table prefix (contains a dot)
|
||||
if strings.Contains(cond, ".") {
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||
if IsSQLExpression(lowerCond) {
|
||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the column name (first identifier before operator)
|
||||
columnName := ExtractColumnName(cond)
|
||||
if columnName == "" {
|
||||
// Can't identify column name, require explicit prefix
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||
}
|
||||
|
||||
// Add relation prefix to the column name only
|
||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||
fixedConditions = append(fixedConditions, fixedCond)
|
||||
}
|
||||
|
||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||
return fixedWhere, nil
|
||||
}
|
||||
|
||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||
func IsSQLExpression(cond string) bool {
|
||||
// Common SQL literals and expressions
|
||||
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||
for _, literal := range sqlLiterals {
|
||||
if cond == literal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
||||
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
||||
func IsTrivialCondition(cond string) bool {
|
||||
cond = strings.TrimSpace(cond)
|
||||
lowerCond := strings.ToLower(cond)
|
||||
|
||||
// Conditions that always evaluate to true
|
||||
trivialConditions := []string{
|
||||
"1=1", "1 = 1", "1= 1", "1 =1",
|
||||
"true", "true = true", "true=true", "true= true", "true =true",
|
||||
"0=0", "0 = 0", "0= 0", "0 =0",
|
||||
}
|
||||
|
||||
for _, trivial := range trivialConditions {
|
||||
if lowerCond == trivial {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause string to sanitize
|
||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||
//
|
||||
// Returns:
|
||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||
// - An empty string if all conditions were trivial or the input was empty
|
||||
func SanitizeWhereClause(where string, tableName string) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
|
||||
// Get valid columns from the model if tableName is provided
|
||||
var validColumns map[string]bool
|
||||
if tableName != "" {
|
||||
validColumns = getValidColumnsForTable(tableName)
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
conditions := splitByAND(where)
|
||||
|
||||
validConditions := make([]string, 0, len(conditions))
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip parentheses from the condition before checking
|
||||
condToCheck := stripOuterParentheses(cond)
|
||||
|
||||
// Skip trivial conditions that always evaluate to true
|
||||
if IsTrivialCondition(condToCheck) {
|
||||
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||
// attempt to add it
|
||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||
// Extract the column name and prefix it
|
||||
columnName := ExtractColumnName(condToCheck)
|
||||
if columnName != "" {
|
||||
// Only prefix if this is a valid column in the model
|
||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace in the original condition (without stripped parens)
|
||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||
} else {
|
||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validConditions = append(validConditions, cond)
|
||||
}
|
||||
|
||||
if len(validConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(validConditions, " AND ")
|
||||
|
||||
if result != where {
|
||||
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// stripOuterParentheses removes matching outer parentheses from a string
|
||||
// It handles nested parentheses correctly
|
||||
func stripOuterParentheses(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
for {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
return s
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s
|
||||
}
|
||||
|
||||
// Strip the outer parentheses and continue
|
||||
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||
func splitByAND(where string) []string {
|
||||
// First try uppercase AND
|
||||
conditions := strings.Split(where, " AND ")
|
||||
|
||||
// If we didn't split on uppercase, try lowercase
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
// If we still didn't split, try mixed case
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " And ")
|
||||
}
|
||||
|
||||
return conditions
|
||||
}
|
||||
|
||||
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
|
||||
func hasTablePrefix(cond string) bool {
|
||||
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
|
||||
return strings.Contains(cond, ".")
|
||||
}
|
||||
|
||||
// ExtractColumnName extracts the column name from a WHERE condition
|
||||
// For example: "status = 'active'" returns "status"
|
||||
func ExtractColumnName(cond string) string {
|
||||
// Common SQL operators
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnName := strings.TrimSpace(cond[:idx])
|
||||
// Remove quotes if present
|
||||
columnName = strings.Trim(columnName, "`\"'")
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnName := strings.Trim(parts[0], "`\"'")
|
||||
// Check if it's a valid identifier (not a SQL keyword)
|
||||
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||
func IsSQLKeyword(word string) bool {
|
||||
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||
for _, kw := range keywords {
|
||||
if word == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
||||
// Returns a map of column names for fast lookup, or nil if the model is not found
|
||||
func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
// Try to get the model from the registry
|
||||
model, err := modelregistry.GetModelByName(tableName)
|
||||
if err != nil {
|
||||
// Model not found, return nil to indicate we should use fallback behavior
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get SQL columns from the model
|
||||
columns := reflection.GetSQLModelColumns(model)
|
||||
if len(columns) == 0 {
|
||||
// No columns found, return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a map for fast lookup
|
||||
columnMap := make(map[string]bool, len(columns))
|
||||
for _, col := range columns {
|
||||
columnMap[strings.ToLower(col)] = true
|
||||
}
|
||||
|
||||
return columnMap
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
if validColumns == nil {
|
||||
return true // No model info, assume valid
|
||||
}
|
||||
return validColumns[strings.ToLower(columnName)]
|
||||
}
|
||||
224
pkg/common/sql_helpers_test.go
Normal file
224
pkg/common/sql_helpers_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
func TestSanitizeWhereClause(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "trivial conditions in parentheses",
|
||||
where: "(true AND true AND true)",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "trivial conditions without parentheses",
|
||||
where: "true AND true AND true",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single trivial condition",
|
||||
where: "true",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "valid condition with parentheses",
|
||||
where: "(status = 'active')",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mixed trivial and valid conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "condition already with table prefix",
|
||||
where: "users.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions",
|
||||
where: "status = 'active' AND age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "no table name provided",
|
||||
where: "status = 'active'",
|
||||
tableName: "",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "empty where clause",
|
||||
where: "",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripOuterParentheses(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single level parentheses",
|
||||
input: "(true)",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "multiple levels",
|
||||
input: "((true))",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "no parentheses",
|
||||
input: "true",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "mismatched parentheses",
|
||||
input: "(true",
|
||||
expected: "(true",
|
||||
},
|
||||
{
|
||||
name: "complex expression",
|
||||
input: "(a AND b)",
|
||||
expected: "a AND b",
|
||||
},
|
||||
{
|
||||
name: "nested but not outer",
|
||||
input: "(a AND (b OR c)) AND d",
|
||||
expected: "(a AND (b OR c)) AND d",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
input: " ( true ) ",
|
||||
expected: "true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := stripOuterParentheses(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTrivialCondition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"true", "true", true},
|
||||
{"true with spaces", " true ", true},
|
||||
{"TRUE uppercase", "TRUE", true},
|
||||
{"1=1", "1=1", true},
|
||||
{"1 = 1", "1 = 1", true},
|
||||
{"true = true", "true = true", true},
|
||||
{"valid condition", "status = 'active'", false},
|
||||
{"false", "false", false},
|
||||
{"column name", "is_active", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsTrivialCondition(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test model for model-aware sanitization tests
|
||||
type MasterTask struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
Status string `bun:"status"`
|
||||
UserID int `bun:"user_id"`
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
// Register the test model
|
||||
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||
if err != nil {
|
||||
// Model might already be registered, ignore error
|
||||
t.Logf("Model registration returned: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid column gets prefixed",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns get prefixed",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "invalid column does not get prefixed",
|
||||
where: "invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "invalid_column = 'value'",
|
||||
},
|
||||
{
|
||||
name: "mix of valid and trivial conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column",
|
||||
where: "(status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func tryParseDT(str string) (time.Time, error) {
|
||||
@@ -236,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
|
||||
if b == nil {
|
||||
t = &SqlTimeStamp{}
|
||||
|
||||
return nil
|
||||
}
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
||||
t = &SqlTimeStamp{}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -291,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error {
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlTimeStamp) String() string {
|
||||
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
|
||||
return time.Time(t).Format("2006-01-02T15:04:05")
|
||||
}
|
||||
|
||||
// GetTime - Returns Time
|
||||
@@ -306,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) {
|
||||
|
||||
// Format - Formats the time
|
||||
func (t SqlTimeStamp) Format(layout string) string {
|
||||
return fmt.Sprintf("%s", time.Time(t).Format(layout))
|
||||
return time.Time(t).Format(layout)
|
||||
}
|
||||
|
||||
func SqlTimeStampNow() SqlTimeStamp {
|
||||
@@ -418,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
||||
s == "0001-01-01" {
|
||||
t = &SqlDate{}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -432,7 +433,7 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
@@ -480,7 +481,7 @@ func (t SqlDate) Int64() int64 {
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlDate) String() string {
|
||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
||||
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
||||
return "0"
|
||||
}
|
||||
@@ -515,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
||||
*t = SqlTime{}
|
||||
return nil
|
||||
}
|
||||
tx := time.Time{}
|
||||
tx, err = tryParseDT(s)
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
*t = SqlTime(tx)
|
||||
|
||||
return err
|
||||
@@ -640,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) {
|
||||
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
|
||||
if invalid {
|
||||
s = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -659,7 +659,7 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
var obj interface{}
|
||||
err := json.Unmarshal(n, &obj)
|
||||
if err != nil {
|
||||
//fmt.Printf("Invalid JSON %v", err)
|
||||
// fmt.Printf("Invalid JSON %v", err)
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
@@ -671,3 +671,101 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
|
||||
return dat, nil
|
||||
}
|
||||
|
||||
// SqlUUID - Nullable UUID String
|
||||
type SqlUUID sql.NullString
|
||||
|
||||
// Scan -
|
||||
func (n *SqlUUID) Scan(value interface{}) error {
|
||||
str := sql.NullString{String: "", Valid: false}
|
||||
if value == nil {
|
||||
*n = SqlUUID(str)
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
uuid, err := uuid.Parse(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
case []uint8:
|
||||
uuid, err := uuid.ParseBytes(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
default:
|
||||
uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlUUID) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.String, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON
|
||||
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 30)
|
||||
if invalid {
|
||||
return nil
|
||||
}
|
||||
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlUUID) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
|
||||
}
|
||||
|
||||
// TryIfInt64 - Wrapper function to quickly try and cast text to int
|
||||
func TryIfInt64(v any, def int64) int64 {
|
||||
str := ""
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
str = val
|
||||
case int:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case []byte:
|
||||
str = string(val)
|
||||
default:
|
||||
str = fmt.Sprintf("%d", def)
|
||||
}
|
||||
val, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
566
pkg/common/sql_types_test.go
Normal file
566
pkg/common/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestSqlInt16 tests SqlInt16 type
|
||||
func TestSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, SqlInt16(42)},
|
||||
{"int32", int32(100), SqlInt16(100)},
|
||||
{"int64", int64(200), SqlInt16(200)},
|
||||
{"string", "123", SqlInt16(123)},
|
||||
{"nil", nil, SqlInt16(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt16
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", SqlInt16(0), nil},
|
||||
{"positive", SqlInt16(42), int64(42)},
|
||||
{"negative", SqlInt16(-10), int64(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_JSON(t *testing.T) {
|
||||
n := SqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := "42"
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var n2 SqlInt16
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2 != 123 {
|
||||
t.Errorf("expected 123, got %d", n2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlInt64 tests SqlInt64 type
|
||||
func TestSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, SqlInt64(42)},
|
||||
{"int32", int32(100), SqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), SqlInt64(100)},
|
||||
{"uint64", uint64(200), SqlInt64(200)},
|
||||
{"nil", nil, SqlInt64(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlFloat64 tests SqlFloat64 type
|
||||
func TestSqlFloat64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected float64
|
||||
valid bool
|
||||
}{
|
||||
{"float64", float64(3.14), 3.14, true},
|
||||
{"float32", float32(2.5), 2.5, true},
|
||||
{"int", 42, 42.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlFloat64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64 != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTimeStamp tests SqlTimeStamp type
|
||||
func TestSqlTimeStamp(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string RFC3339", now.Format(time.RFC3339)},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string datetime", "2024-01-15T10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ts SqlTimeStamp
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.GetTime().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := SqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15T10:30:45"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var ts2 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.GetTime().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
var ts3 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlDate tests SqlDate type
|
||||
func TestSqlDate(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string UK format", "15/01/2024"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var d SqlDate
|
||||
if err := d.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if d.String() == "0" {
|
||||
t.Error("expected non-zero date")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var d2 SqlDate
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTime tests SqlTime type
|
||||
func TestSqlTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"time.Time", now, now.Format("15:04:05")},
|
||||
{"string time", "10:30:45", "10:30:45"},
|
||||
{"string short time", "10:30", "10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tm SqlTime
|
||||
if err := tm.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tm.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, tm.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlJSONB tests SqlJSONB type
|
||||
func TestSqlJSONB_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
|
||||
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
|
||||
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
|
||||
{"nil", nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var j SqlJSONB
|
||||
if err := j.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && j == nil {
|
||||
return // nil case
|
||||
}
|
||||
if string(j) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, string(j))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
|
||||
{"empty", SqlJSONB{}, "", false},
|
||||
{"nil", nil, "", false},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && val == nil {
|
||||
return // nil case
|
||||
}
|
||||
if val.(string) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_JSON(t *testing.T) {
|
||||
// Marshal
|
||||
j := SqlJSONB(`{"name":"test","count":42}`)
|
||||
data, err := json.Marshal(j)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("Unmarshal result failed: %v", err)
|
||||
}
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("expected name=test, got %v", result["name"])
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var j2 SqlJSONB
|
||||
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if string(j2) != `{"key":"value"}` {
|
||||
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
|
||||
}
|
||||
|
||||
// Test null
|
||||
var j3 SqlJSONB
|
||||
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m, err := tt.input.AsMap()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsMap failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
t.Error("expected non-nil map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := tt.input.AsSlice()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsSlice failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
t.Error("expected non-nil slice")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlUUID tests SqlUUID type
|
||||
func TestSqlUUID_Scan(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testUUIDStr := testUUID.String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||
{"nil", nil, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var u SqlUUID
|
||||
if err := u.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
// Test invalid UUID
|
||||
u2 := SqlUUID{Valid: false}
|
||||
val2, err := u2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val2 != nil {
|
||||
t.Errorf("expected nil, got %v", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"` + testUUID.String() + `"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var u2 SqlUUID
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||
}
|
||||
|
||||
// Test null
|
||||
var u3 SqlUUID
|
||||
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if u3.Valid {
|
||||
t.Error("expected invalid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||
func TestTryIfInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
def int64
|
||||
expected int64
|
||||
}{
|
||||
{"string valid", "123", 0, 123},
|
||||
{"string invalid", "abc", 99, 99},
|
||||
{"int", 42, 0, 42},
|
||||
{"int32", int32(100), 0, 100},
|
||||
{"int64", int64(200), 0, 200},
|
||||
{"uint32", uint32(50), 0, 50},
|
||||
{"uint64", uint64(75), 0, 75},
|
||||
{"float32", float32(3.14), 0, 3},
|
||||
{"float64", float64(2.71), 0, 2},
|
||||
{"bytes", []byte("456"), 0, 456},
|
||||
{"unknown type", struct{}{}, 999, 999},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TryIfInt64(tt.input, tt.def)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -32,15 +32,22 @@ type Parameter struct {
|
||||
}
|
||||
|
||||
type PreloadOption struct {
|
||||
Relation string `json:"relation"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
Filters []FilterOption `json:"filters"`
|
||||
Where string `json:"where"`
|
||||
Limit *int `json:"limit"`
|
||||
Offset *int `json:"offset"`
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
Relation string `json:"relation"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
Filters []FilterOption `json:"filters"`
|
||||
Where string `json:"where"`
|
||||
Limit *int `json:"limit"`
|
||||
Offset *int `json:"offset"`
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||
|
||||
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ColumnValidator validates column names against a model's fields
|
||||
@@ -95,6 +96,7 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
||||
// ValidateColumn validates a single column name
|
||||
// Returns nil if valid, error if invalid
|
||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||
// Handles PostgreSQL JSON operators (-> and ->>)
|
||||
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||
// Allow empty columns
|
||||
if column == "" {
|
||||
@@ -106,8 +108,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract source column name (remove JSON operators like ->> or ->)
|
||||
sourceColumn := reflection.ExtractSourceColumn(column)
|
||||
|
||||
// Check if column exists in model
|
||||
if _, exists := v.validColumns[strings.ToLower(column)]; !exists {
|
||||
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||
}
|
||||
|
||||
|
||||
126
pkg/common/validation_json_test.go
Normal file
126
pkg/common/validation_json_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func TestExtractSourceColumn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple column name",
|
||||
input: "columna",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with ->> operator",
|
||||
input: "columna->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with -> operator",
|
||||
input: "columna->'key'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and ->> operator",
|
||||
input: "table.columna->>'val'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and -> operator",
|
||||
input: "table.columna->'key'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "complex JSON path with ->>",
|
||||
input: "data->>'nested'->>'value'",
|
||||
expected: "data",
|
||||
},
|
||||
{
|
||||
name: "column with spaces before operator",
|
||||
input: "columna ->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := reflection.ExtractSourceColumn(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnWithJSONOperators(t *testing.T) {
|
||||
// Create a test model
|
||||
type TestModel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Data string `json:"data"` // JSON column
|
||||
Metadata string `json:"metadata"`
|
||||
}
|
||||
|
||||
validator := NewColumnValidator(TestModel{})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
column string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple valid column",
|
||||
column: "name",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with ->> operator",
|
||||
column: "data->>'field'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with -> operator",
|
||||
column: "metadata->'key'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid column",
|
||||
column: "invalid_column",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid column with ->> operator",
|
||||
column: "invalid_column->>'field'",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "cql prefixed column (always valid)",
|
||||
column: "cql_computed",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty column",
|
||||
column: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validator.ValidateColumn(tc.column)
|
||||
if tc.shouldErr && err == nil {
|
||||
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
|
||||
}
|
||||
if !tc.shouldErr && err != nil {
|
||||
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -103,3 +103,18 @@ func CatchPanicCallback(location string, cb func(err any)) {
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
}
|
||||
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
// This should be called with the result of recover() from a deferred function
|
||||
// Example usage:
|
||||
//
|
||||
// defer func() {
|
||||
// if r := recover(); r != nil {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||
}
|
||||
|
||||
@@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
}
|
||||
}()
|
||||
|
||||
var lst []ModelFieldDetail
|
||||
lst = make([]ModelFieldDetail, 0)
|
||||
lst := make([]ModelFieldDetail, 0)
|
||||
|
||||
if !record.IsValid() {
|
||||
return lst
|
||||
|
||||
@@ -17,3 +17,33 @@ func Len(v any) int {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
|
||||
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
|
||||
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
|
||||
// dots, it returns everything after the last dot up to the first delimiter.
|
||||
func ExtractTableNameOnly(fullName string) string {
|
||||
// First, split by dot to remove schema prefix if present
|
||||
lastDotIndex := -1
|
||||
for i, char := range fullName {
|
||||
if char == '.' {
|
||||
lastDotIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// Start from after the last dot (or from beginning if no dot)
|
||||
startIndex := 0
|
||||
if lastDotIndex != -1 {
|
||||
startIndex = lastDotIndex + 1
|
||||
}
|
||||
|
||||
// Now find the end (first delimiter after the table name)
|
||||
for i := startIndex; i < len(fullName); i++ {
|
||||
char := rune(fullName[i])
|
||||
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
|
||||
return fullName[startIndex:i]
|
||||
}
|
||||
}
|
||||
|
||||
return fullName[startIndex:]
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
@@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any {
|
||||
}
|
||||
|
||||
// Check if field name matches
|
||||
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
||||
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
@@ -323,6 +325,127 @@ func ExtractColumnFromBunTag(tag string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||
// This function only returns columns that:
|
||||
// 1. Have bun or gorm tags (not just json tags)
|
||||
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||
// 3. Are not scan-only embedded fields
|
||||
func GetSQLModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
collectSQLColumnsFromType(modelType, &columns, false)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Check if the embedded struct itself is scan-only
|
||||
isScanOnly := scanOnlyEmbedded
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
isScanOnly = true
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip fields in scan-only embedded structs
|
||||
if scanOnlyEmbedded {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get bun and gorm tags
|
||||
bunTag := field.Tag.Get("bun")
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
|
||||
// Skip if neither bun nor gorm tag exists
|
||||
if bunTag == "" && gormTag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if explicitly marked with "-"
|
||||
if bunTag == "-" || gormTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is scan-only (bun)
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is read-only (gorm)
|
||||
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (bun)
|
||||
if bunTag != "" {
|
||||
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||
if strings.Contains(bunTag, "rel:") ||
|
||||
strings.Contains(bunTag, "join:") ||
|
||||
strings.Contains(bunTag, "m2m:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip relation fields (gorm)
|
||||
if gormTag != "" {
|
||||
// Skip if it has gorm relationship tags
|
||||
if strings.Contains(gormTag, "foreignKey:") ||
|
||||
strings.Contains(gormTag, "references:") ||
|
||||
strings.Contains(gormTag, "many2many:") ||
|
||||
strings.Contains(gormTag, "constraint:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name
|
||||
columnName := ""
|
||||
if bunTag != "" {
|
||||
columnName = ExtractColumnFromBunTag(bunTag)
|
||||
}
|
||||
if columnName == "" && gormTag != "" {
|
||||
columnName = ExtractColumnFromGormTag(gormTag)
|
||||
}
|
||||
|
||||
// Skip if we couldn't extract a column name
|
||||
if columnName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
// IsColumnWritable checks if a column can be written to in the database
|
||||
// For bun: returns false if the field has "scanonly" tag
|
||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||
@@ -351,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool {
|
||||
|
||||
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||
// Returns (found, writable) where found indicates if the column was found
|
||||
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
||||
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
@@ -440,3 +563,321 @@ func isGormFieldReadOnly(tag string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||
// Examples:
|
||||
// - "columna->>'val'" returns "columna"
|
||||
// - "columna->'key'" returns "columna"
|
||||
// - "columna" returns "columna"
|
||||
// - "table.columna->>'val'" returns "table.columna"
|
||||
func ExtractSourceColumn(colName string) string {
|
||||
// Check for PostgreSQL JSON operators: -> and ->>
|
||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
return colName
|
||||
}
|
||||
|
||||
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||
func ToSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||
if model == nil {
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||
sourceColName := ExtractSourceColumn(colName)
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// Find the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" {
|
||||
// Parse JSON tag (format: "name,omitempty")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if parts[0] == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
// Check field name (case-insensitive)
|
||||
if strings.EqualFold(field.Name, sourceColName) {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
|
||||
// Check snake_case conversion
|
||||
snakeCaseName := ToSnakeCase(field.Name)
|
||||
if snakeCaseName == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// IsNumericType checks if a reflect.Kind is a numeric type
|
||||
func IsNumericType(kind reflect.Kind) bool {
|
||||
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||
}
|
||||
|
||||
// IsStringType checks if a reflect.Kind is a string type
|
||||
func IsStringType(kind reflect.Kind) bool {
|
||||
return kind == reflect.String
|
||||
}
|
||||
|
||||
// IsNumericValue checks if a string value can be parsed as a number
|
||||
func IsNumericValue(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
_, err := strconv.ParseFloat(value, 64)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ConvertToNumericType converts a string value to the appropriate numeric type
|
||||
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
// Parse as integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Int8:
|
||||
bitSize = 8
|
||||
case reflect.Int16:
|
||||
bitSize = 16
|
||||
case reflect.Int32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
return int(intVal), nil
|
||||
case reflect.Int8:
|
||||
return int8(intVal), nil
|
||||
case reflect.Int16:
|
||||
return int16(intVal), nil
|
||||
case reflect.Int32:
|
||||
return int32(intVal), nil
|
||||
case reflect.Int64:
|
||||
return intVal, nil
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
// Parse as unsigned integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Uint8:
|
||||
bitSize = 8
|
||||
case reflect.Uint16:
|
||||
bitSize = 16
|
||||
case reflect.Uint32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Uint:
|
||||
return uint(uintVal), nil
|
||||
case reflect.Uint8:
|
||||
return uint8(uintVal), nil
|
||||
case reflect.Uint16:
|
||||
return uint16(uintVal), nil
|
||||
case reflect.Uint32:
|
||||
return uint32(uintVal), nil
|
||||
case reflect.Uint64:
|
||||
return uintVal, nil
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
// Parse as float
|
||||
bitSize := 64
|
||||
if kind == reflect.Float32 {
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||
}
|
||||
|
||||
if kind == reflect.Float32 {
|
||||
return float32(floatVal), nil
|
||||
}
|
||||
return floatVal, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||
}
|
||||
|
||||
// GetRelationModel gets the model type for a relation field
|
||||
// It searches for the field by name in the following order (case-insensitive):
|
||||
// 1. Actual field name
|
||||
// 2. Bun tag name (if exists)
|
||||
// 3. Gorm tag name (if exists)
|
||||
// 4. JSON tag name (if exists)
|
||||
//
|
||||
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
|
||||
// For nested fields, it traverses through each level of the struct hierarchy
|
||||
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split the field name by "." to handle nested/recursive relations
|
||||
fieldParts := strings.Split(fieldName, ".")
|
||||
|
||||
// Start with the current model
|
||||
currentModel := model
|
||||
|
||||
// Traverse through each level of the field path
|
||||
for _, part := range fieldParts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
currentModel = getRelationModelSingleLevel(currentModel, part)
|
||||
if currentModel == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return currentModel
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find the field by checking in priority order (case-insensitive)
|
||||
var field *reflect.StructField
|
||||
normalizedFieldName := strings.ToLower(fieldName)
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
f := modelType.Field(i)
|
||||
|
||||
// 1. Check actual field name (case-insensitive)
|
||||
if strings.EqualFold(f.Name, fieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
|
||||
// 2. Check bun tag name
|
||||
bunTag := f.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
bunColName := ExtractColumnFromBunTag(bunTag)
|
||||
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check gorm tag name
|
||||
gormTag := f.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
gormColName := ExtractColumnFromGormTag(gormTag)
|
||||
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Check JSON tag name
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag != "" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||
if strings.EqualFold(parts[0], normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the target type
|
||||
targetType := field.Type
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if targetType.Kind() == reflect.Slice {
|
||||
targetType = targetType.Elem()
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if targetType.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a zero value of the target type
|
||||
return reflect.New(targetType).Elem().Interface()
|
||||
}
|
||||
|
||||
@@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with relations for GetSQLModelColumns
|
||||
type User struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
Email string `bun:"email" json:"email"`
|
||||
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Title string `gorm:"column:title" json:"title"`
|
||||
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||
Content string `json:"content"` // No bun/gorm tag
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Bio string `bun:"bio" json:"bio"`
|
||||
UserID int `bun:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
}
|
||||
|
||||
// Model with scan-only embedded struct
|
||||
type EntityWithScanOnlyEmbedded struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||
model: User{},
|
||||
// Should include: id, name, email (has bun tags)
|
||||
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||
model: Post{},
|
||||
// Should include: id, title, user_id (has gorm tags)
|
||||
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||
expected: []string{"id", "title", "user_id"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded base and scan-only embedded",
|
||||
model: EntityWithScanOnlyEmbedded{},
|
||||
// Should include: id, name from main struct
|
||||
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||
expected: []string{"id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: ModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: GormModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "Simple Profile model",
|
||||
model: Profile{},
|
||||
// Should include all fields with bun tags
|
||||
expected: []string{"id", "bio", "user_id"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetSQLModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||
len(result), len(tt.expected), result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||
i, col, tt.expected[i], result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||
user := User{}
|
||||
|
||||
allColumns := GetModelColumns(user)
|
||||
sqlColumns := GetSQLModelColumns(user)
|
||||
|
||||
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||
|
||||
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||
if len(allColumns) <= len(sqlColumns) {
|
||||
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||
}
|
||||
|
||||
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||
for _, col := range sqlColumns {
|
||||
if col == "profile_data" {
|
||||
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||
}
|
||||
}
|
||||
|
||||
// GetModelColumns should include 'profile_data' (has json tag)
|
||||
hasProfileData := false
|
||||
for _, col := range allColumns {
|
||||
if col == "profile_data" {
|
||||
hasProfileData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasProfileData {
|
||||
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
|
||||
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
|
||||
// Apply column selection
|
||||
if len(options.Columns) > 0 {
|
||||
logger.Debug("Selecting columns: %v", options.Columns)
|
||||
@@ -1132,15 +1137,25 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||
relationFieldName := relInfo.fieldName
|
||||
|
||||
// For now, we'll preload without conditions
|
||||
// TODO: Implement column selection and filtering for preloads
|
||||
// This requires a more sophisticated approach with callbacks or query builders
|
||||
// Apply preloading
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||
if err != nil {
|
||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||
}
|
||||
preload.Where = fixedWhere
|
||||
}
|
||||
|
||||
logger.Debug("Applying preload: %s", relationFieldName)
|
||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||
preload.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
|
||||
// Handle column selection and omission
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := reflection.GetModelColumns(model)
|
||||
allCols := reflection.GetSQLModelColumns(model)
|
||||
// Remove omitted columns
|
||||
preload.Columns = []string{}
|
||||
for _, col := range allCols {
|
||||
@@ -1194,7 +1209,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
}
|
||||
|
||||
if len(preload.Where) > 0 {
|
||||
sq = sq.Where(preload.Where)
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
}
|
||||
|
||||
if preload.Limit != nil && *preload.Limit > 0 {
|
||||
|
||||
@@ -13,6 +13,7 @@ const (
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
contextKeyOptions contextKey = "options"
|
||||
)
|
||||
|
||||
// WithSchema adds schema to context
|
||||
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
// WithOptions adds request options to context
|
||||
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
|
||||
return context.WithValue(ctx, contextKeyOptions, options)
|
||||
}
|
||||
|
||||
// GetOptions retrieves request options from context
|
||||
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
|
||||
if v := ctx.Value(contextKeyOptions); v != nil {
|
||||
if opts, ok := v.(ExtendedRequestOptions); ok {
|
||||
return &opts
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithRequestData adds all request-scoped data to context at once
|
||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
ctx = WithOptions(ctx, options)
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
entity := params["entity"]
|
||||
id := params["id"]
|
||||
|
||||
// Parse options from headers (now returns ExtendedRequestOptions)
|
||||
options := h.parseOptionsFromHeaders(r)
|
||||
|
||||
// Determine operation based on HTTP method
|
||||
method := r.Method()
|
||||
|
||||
@@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
|
||||
// Add request-scoped data to context
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr)
|
||||
// Parse options from headers - this now includes relation name resolution
|
||||
options := h.parseOptionsFromHeaders(r, model)
|
||||
|
||||
// Validate and filter columns in options (log warnings for invalid columns)
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = filterExtendedOptions(validator, options)
|
||||
|
||||
// Add request-scoped data to context (including options)
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||
|
||||
switch method {
|
||||
case "GET":
|
||||
if id != "" {
|
||||
@@ -260,6 +260,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// If we have computed columns/expressions but options.Columns is empty,
|
||||
// populate it with all model columns first since computed columns are additions
|
||||
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
||||
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
|
||||
// Apply ComputedQL fields if any
|
||||
if len(options.ComputedQL) > 0 {
|
||||
for colName, colExpr := range options.ComputedQL {
|
||||
@@ -340,50 +347,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
logger.Debug("Applying preload: %s", preload.Relation)
|
||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := reflection.GetModelColumns(model)
|
||||
// Remove omitted columns
|
||||
preload.Columns = []string{}
|
||||
for _, col := range allCols {
|
||||
addCols := true
|
||||
for _, omitCol := range preload.OmitColumns {
|
||||
if col == omitCol {
|
||||
addCols = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if addCols {
|
||||
preload.Columns = append(preload.Columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(preload.Columns) > 0 {
|
||||
sq = sq.Column(preload.Columns...)
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||
if err != nil {
|
||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
|
||||
return
|
||||
}
|
||||
preload.Where = fixedWhere
|
||||
}
|
||||
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
sq = h.applyFilter(sq, filter, "", false, "AND")
|
||||
}
|
||||
}
|
||||
if len(preload.Sort) > 0 {
|
||||
for _, sort := range preload.Sort {
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
|
||||
if len(preload.Where) > 0 {
|
||||
sq = sq.Where(preload.Where)
|
||||
}
|
||||
|
||||
if preload.Limit != nil && *preload.Limit > 0 {
|
||||
sq = sq.Limit(*preload.Limit)
|
||||
}
|
||||
|
||||
return sq
|
||||
})
|
||||
// Apply the preload with recursive support
|
||||
query = h.applyPreloadWithRecursion(query, preload, model, 0)
|
||||
}
|
||||
|
||||
// Apply DISTINCT if requested
|
||||
@@ -413,13 +391,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
if options.CustomSQLWhere != "" {
|
||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||
query = query.Where(options.CustomSQLWhere)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL WHERE clause (OR condition)
|
||||
if options.CustomSQLOr != "" {
|
||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||
query = query.WhereOr(options.CustomSQLOr)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
}
|
||||
|
||||
// If ID is provided, filter by ID
|
||||
@@ -495,7 +481,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
query = query.Where(cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,6 +558,143 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
h.sendFormattedResponse(w, modelPtr, metadata, options)
|
||||
}
|
||||
|
||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||
// Log relationship keys if they're specified (from XFiles)
|
||||
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
|
||||
|
||||
// Build a WHERE clause using the relationship keys if needed
|
||||
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
|
||||
// However, if the relationship keys are explicitly provided from XFiles, we can use them
|
||||
// to add additional filtering or validation
|
||||
if preload.RelatedKey != "" && preload.Where == "" {
|
||||
// For child tables: ensure the child's relatedkey column will be matched
|
||||
// The actual parent value is dynamic and handled by Bun's preload mechanism
|
||||
// We just log this for visibility
|
||||
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
|
||||
preload.Relation, preload.RelatedKey)
|
||||
}
|
||||
if preload.ForeignKey != "" && preload.Where == "" {
|
||||
// For parent tables: ensure the parent's primary key matches the current table's foreign key
|
||||
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
|
||||
preload.Relation, preload.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the preload
|
||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||
// Get the related model for column operations
|
||||
relatedModel := reflection.GetRelationModel(model, preload.Relation)
|
||||
if relatedModel == nil {
|
||||
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
||||
// relatedModel = model // fallback to parent model
|
||||
} else {
|
||||
|
||||
// If we have computed columns but no explicit columns, populate with all model columns first
|
||||
// since computed columns are additions
|
||||
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
|
||||
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
|
||||
}
|
||||
|
||||
// Apply ComputedQL fields if any
|
||||
if len(preload.ComputedQL) > 0 {
|
||||
for colName, colExpr := range preload.ComputedQL {
|
||||
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
||||
// Remove the computed column from selected columns to avoid duplication
|
||||
for colIndex := range preload.Columns {
|
||||
if preload.Columns[colIndex] == colName {
|
||||
preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle OmitColumns
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := preload.Columns
|
||||
// Remove omitted columns
|
||||
preload.Columns = []string{}
|
||||
for _, col := range allCols {
|
||||
addCols := true
|
||||
for _, omitCol := range preload.OmitColumns {
|
||||
if col == omitCol {
|
||||
addCols = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if addCols {
|
||||
preload.Columns = append(preload.Columns, col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply column selection
|
||||
if len(preload.Columns) > 0 {
|
||||
sq = sq.Column(preload.Columns...)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
sq = h.applyFilter(sq, filter, "", false, "AND")
|
||||
}
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
if len(preload.Sort) > 0 {
|
||||
for _, sort := range preload.Sort {
|
||||
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Apply WHERE clause
|
||||
if len(preload.Where) > 0 {
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply limit
|
||||
if preload.Limit != nil && *preload.Limit > 0 {
|
||||
sq = sq.Limit(*preload.Limit)
|
||||
}
|
||||
|
||||
if preload.Offset != nil && *preload.Offset > 0 {
|
||||
sq = sq.Offset(*preload.Offset)
|
||||
}
|
||||
|
||||
return sq
|
||||
})
|
||||
|
||||
// Handle recursive preloading
|
||||
if preload.Recursive && depth < 5 {
|
||||
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||
|
||||
// For recursive relationships, we need to get the last part of the relation path
|
||||
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
|
||||
relationParts := strings.Split(preload.Relation, ".")
|
||||
lastRelationName := relationParts[len(relationParts)-1]
|
||||
|
||||
// Create a recursive preload with the same configuration
|
||||
// but with the relation path extended
|
||||
recursivePreload := preload
|
||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||
|
||||
// Recursively apply preload until we reach depth 5
|
||||
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||
// Capture panics and return error response
|
||||
defer func() {
|
||||
@@ -663,7 +789,14 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
||||
}
|
||||
|
||||
// Create insert query
|
||||
query := tx.NewInsert().Model(modelValue).Table(tableName).Returning("*")
|
||||
query := tx.NewInsert().Model(modelValue)
|
||||
|
||||
// Only set Table() if the model doesn't provide a table name via TableNameProvider
|
||||
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
query = query.Returning("*")
|
||||
|
||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||
itemHookCtx := &HookContext{
|
||||
@@ -1222,7 +1355,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||
func (h *Handler) extractNestedRelations(
|
||||
data map[string]interface{},
|
||||
model interface{},
|
||||
) (map[string]interface{}, map[string]interface{}, error) {
|
||||
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
||||
// Get model type for reflection
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
@@ -1640,13 +1773,9 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
||||
data = h.normalizeResultArray(data)
|
||||
}
|
||||
|
||||
response := common.Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Metadata: metadata,
|
||||
}
|
||||
// Return data as-is without wrapping in common.Response
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := w.WriteJSON(response); err != nil {
|
||||
if err := w.WriteJSON(data); err != nil {
|
||||
logger.Error("Failed to write JSON response: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1655,7 +1784,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||
if data == nil {
|
||||
return data
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use reflection to check if data is a slice or array
|
||||
|
||||
@@ -2,6 +2,7 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -9,6 +10,7 @@ import (
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ExtendedRequestOptions extends common.RequestOptions with additional features
|
||||
@@ -42,6 +44,9 @@ type ExtendedRequestOptions struct {
|
||||
|
||||
// Transaction
|
||||
AtomicTransaction bool
|
||||
|
||||
// X-Files configuration - comprehensive query options as a single JSON object
|
||||
XFiles *XFiles
|
||||
}
|
||||
|
||||
// ExpandOption represents a relation expansion configuration
|
||||
@@ -95,7 +100,8 @@ func DecodeParam(pStr string) (string, error) {
|
||||
}
|
||||
|
||||
// parseOptionsFromHeaders parses all request options from HTTP headers
|
||||
func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions {
|
||||
// If model is provided, it will resolve table names to field names in preload/expand options
|
||||
func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions {
|
||||
options := ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Filters: make([]common.FilterOption, 0),
|
||||
@@ -105,105 +111,140 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
||||
AdvancedSQL: make(map[string]string),
|
||||
ComputedQL: make(map[string]string),
|
||||
Expand: make([]ExpandOption, 0),
|
||||
ResponseFormat: "simple", // Default response format
|
||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||
ResponseFormat: "simple", // Default response format
|
||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||
}
|
||||
|
||||
// Get all headers
|
||||
headers := r.AllHeaders()
|
||||
|
||||
// Process each header
|
||||
for key, value := range headers {
|
||||
// Normalize header key to lowercase for consistent matching
|
||||
normalizedKey := strings.ToLower(key)
|
||||
// Get all query parameters
|
||||
queryParams := r.AllQueryParams()
|
||||
|
||||
// Merge headers and query parameters - query parameters take precedence
|
||||
// This allows the same parameters to be specified in either headers or query string
|
||||
// Normalize keys to lowercase to ensure query params properly override headers
|
||||
combinedParams := make(map[string]string)
|
||||
for key, value := range headers {
|
||||
combinedParams[strings.ToLower(key)] = value
|
||||
}
|
||||
for key, value := range queryParams {
|
||||
combinedParams[strings.ToLower(key)] = value
|
||||
}
|
||||
|
||||
// Process each parameter (from both headers and query params)
|
||||
// Note: keys are already normalized to lowercase in combinedParams
|
||||
for key, value := range combinedParams {
|
||||
// Decode value if it's base64 encoded
|
||||
decodedValue := decodeHeaderValue(value)
|
||||
|
||||
// Parse based on header prefix/name
|
||||
// Parse based on parameter prefix/name
|
||||
switch {
|
||||
// Field Selection
|
||||
case strings.HasPrefix(normalizedKey, "x-select-fields"):
|
||||
case strings.HasPrefix(key, "x-select-fields"):
|
||||
h.parseSelectFields(&options, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
|
||||
case strings.HasPrefix(key, "x-not-select-fields"):
|
||||
h.parseNotSelectFields(&options, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-clean-json"):
|
||||
case strings.HasPrefix(key, "x-clean-json"):
|
||||
options.CleanJSON = strings.EqualFold(decodedValue, "true")
|
||||
|
||||
// Filtering & Search
|
||||
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
|
||||
h.parseFieldFilter(&options, normalizedKey, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
|
||||
h.parseSearchFilter(&options, normalizedKey, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-searchop-"):
|
||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
||||
case strings.HasPrefix(normalizedKey, "x-searchor-"):
|
||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
|
||||
case strings.HasPrefix(normalizedKey, "x-searchand-"):
|
||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
||||
case strings.HasPrefix(normalizedKey, "x-searchcols"):
|
||||
case strings.HasPrefix(key, "x-fieldfilter-"):
|
||||
h.parseFieldFilter(&options, key, decodedValue)
|
||||
case strings.HasPrefix(key, "x-searchfilter-"):
|
||||
h.parseSearchFilter(&options, key, decodedValue)
|
||||
case strings.HasPrefix(key, "x-searchop-"):
|
||||
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||
case strings.HasPrefix(key, "x-searchor-"):
|
||||
h.parseSearchOp(&options, key, decodedValue, "OR")
|
||||
case strings.HasPrefix(key, "x-searchand-"):
|
||||
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||
case strings.HasPrefix(key, "x-searchcols"):
|
||||
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
|
||||
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||
options.CustomSQLWhere = decodedValue
|
||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
|
||||
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||
options.CustomSQLOr = decodedValue
|
||||
|
||||
// Joins & Relations
|
||||
case strings.HasPrefix(normalizedKey, "x-preload"):
|
||||
if strings.HasSuffix(normalizedKey, "-where") {
|
||||
case strings.HasPrefix(key, "x-preload"):
|
||||
if strings.HasSuffix(key, "-where") {
|
||||
continue
|
||||
}
|
||||
whereClaude := headers[fmt.Sprintf("%s-where", key)]
|
||||
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
|
||||
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||
|
||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
||||
case strings.HasPrefix(key, "x-expand"):
|
||||
h.parseExpand(&options, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
||||
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||
// TODO: Implement custom SQL join
|
||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||
|
||||
// Sorting & Pagination
|
||||
case strings.HasPrefix(normalizedKey, "x-sort"):
|
||||
case strings.HasPrefix(key, "x-sort"):
|
||||
h.parseSorting(&options, decodedValue)
|
||||
case strings.HasPrefix(normalizedKey, "x-limit"):
|
||||
// Special cases for older clients using sort(a,b,-c) syntax
|
||||
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
|
||||
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||
h.parseSorting(&options, sortValue)
|
||||
case strings.HasPrefix(key, "x-limit"):
|
||||
if limit, err := strconv.Atoi(decodedValue); err == nil {
|
||||
options.Limit = &limit
|
||||
}
|
||||
case strings.HasPrefix(normalizedKey, "x-offset"):
|
||||
// Special cases for older clients using limit(n) syntax
|
||||
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
|
||||
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||
limitValueParts := strings.Split(limitValue, ",")
|
||||
|
||||
if len(limitValueParts) > 1 {
|
||||
if offset, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||
options.Offset = &offset
|
||||
}
|
||||
if limit, err := strconv.Atoi(limitValueParts[1]); err == nil {
|
||||
options.Limit = &limit
|
||||
}
|
||||
} else {
|
||||
if limit, err := strconv.Atoi(limitValueParts[0]); err == nil {
|
||||
options.Limit = &limit
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(key, "x-offset"):
|
||||
if offset, err := strconv.Atoi(decodedValue); err == nil {
|
||||
options.Offset = &offset
|
||||
}
|
||||
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
|
||||
|
||||
case strings.HasPrefix(key, "x-cursor-forward"):
|
||||
options.CursorForward = decodedValue
|
||||
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
|
||||
case strings.HasPrefix(key, "x-cursor-backward"):
|
||||
options.CursorBackward = decodedValue
|
||||
|
||||
// Advanced Features
|
||||
case strings.HasPrefix(normalizedKey, "x-advsql-"):
|
||||
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
|
||||
case strings.HasPrefix(key, "x-advsql-"):
|
||||
colName := strings.TrimPrefix(key, "x-advsql-")
|
||||
options.AdvancedSQL[colName] = decodedValue
|
||||
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
|
||||
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
|
||||
case strings.HasPrefix(key, "x-cql-sel-"):
|
||||
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
||||
options.ComputedQL[colName] = decodedValue
|
||||
case strings.HasPrefix(normalizedKey, "x-distinct"):
|
||||
case strings.HasPrefix(key, "x-distinct"):
|
||||
options.Distinct = strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(normalizedKey, "x-skipcount"):
|
||||
case strings.HasPrefix(key, "x-skipcount"):
|
||||
options.SkipCount = strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(normalizedKey, "x-skipcache"):
|
||||
case strings.HasPrefix(key, "x-skipcache"):
|
||||
options.SkipCache = strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
|
||||
case strings.HasPrefix(key, "x-fetch-rownumber"):
|
||||
options.FetchRowNumber = &decodedValue
|
||||
case strings.HasPrefix(normalizedKey, "x-pkrow"):
|
||||
case strings.HasPrefix(key, "x-pkrow"):
|
||||
options.PKRow = &decodedValue
|
||||
|
||||
// Response Format
|
||||
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
|
||||
case strings.HasPrefix(key, "x-simpleapi"):
|
||||
options.ResponseFormat = "simple"
|
||||
case strings.HasPrefix(normalizedKey, "x-detailapi"):
|
||||
case strings.HasPrefix(key, "x-detailapi"):
|
||||
options.ResponseFormat = "detail"
|
||||
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
||||
case strings.HasPrefix(key, "x-syncfusion"):
|
||||
options.ResponseFormat = "syncfusion"
|
||||
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
|
||||
case strings.HasPrefix(key, "x-single-record-as-object"):
|
||||
// Parse as boolean - "false" disables, "true" enables (default is true)
|
||||
if strings.EqualFold(decodedValue, "false") {
|
||||
options.SingleRecordAsObject = false
|
||||
@@ -212,11 +253,26 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio
|
||||
}
|
||||
|
||||
// Transaction Control
|
||||
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
||||
case strings.HasPrefix(key, "x-transaction-atomic"):
|
||||
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
|
||||
|
||||
// X-Files - comprehensive JSON configuration
|
||||
case strings.HasPrefix(key, "x-files"):
|
||||
h.parseXFiles(&options, decodedValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve relation names (convert table names to field names) if model is provided
|
||||
if model != nil {
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
|
||||
//Always sort according to the primary key if no sorting is specified
|
||||
if len(options.Sort) == 0 {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
@@ -480,170 +536,419 @@ func (h *Handler) parseCommaSeparated(value string) []string {
|
||||
return result
|
||||
}
|
||||
|
||||
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||
if model == nil {
|
||||
return reflect.Invalid
|
||||
// parseXFiles parses x-files header containing comprehensive JSON configuration
|
||||
// and populates ExtendedRequestOptions fields from it
|
||||
func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var xfiles XFiles
|
||||
if err := json.Unmarshal([]byte(value), &xfiles); err != nil {
|
||||
logger.Warn("Failed to parse x-files header: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug("Parsed x-files configuration for table: %s", xfiles.TableName)
|
||||
|
||||
// Store the original XFiles for reference
|
||||
options.XFiles = &xfiles
|
||||
|
||||
// Map XFiles fields to ExtendedRequestOptions
|
||||
|
||||
// Column selection
|
||||
if len(xfiles.Columns) > 0 {
|
||||
options.Columns = append(options.Columns, xfiles.Columns...)
|
||||
logger.Debug("X-Files: Added columns: %v", xfiles.Columns)
|
||||
}
|
||||
|
||||
// Omit columns
|
||||
if len(xfiles.OmitColumns) > 0 {
|
||||
options.OmitColumns = append(options.OmitColumns, xfiles.OmitColumns...)
|
||||
logger.Debug("X-Files: Added omit columns: %v", xfiles.OmitColumns)
|
||||
}
|
||||
|
||||
// Computed columns (CQL) -> ComputedQL
|
||||
if len(xfiles.CQLColumns) > 0 {
|
||||
if options.ComputedQL == nil {
|
||||
options.ComputedQL = make(map[string]string)
|
||||
}
|
||||
for i, cqlExpr := range xfiles.CQLColumns {
|
||||
colName := fmt.Sprintf("cql%d", i+1)
|
||||
options.ComputedQL[colName] = cqlExpr
|
||||
logger.Debug("X-Files: Added computed column %s: %s", colName, cqlExpr)
|
||||
}
|
||||
}
|
||||
|
||||
// Sorting
|
||||
if len(xfiles.Sort) > 0 {
|
||||
for _, sortField := range xfiles.Sort {
|
||||
direction := "ASC"
|
||||
colName := sortField
|
||||
|
||||
// Handle direction prefixes
|
||||
if strings.HasPrefix(sortField, "-") {
|
||||
direction = "DESC"
|
||||
colName = strings.TrimPrefix(sortField, "-")
|
||||
} else if strings.HasPrefix(sortField, "+") {
|
||||
colName = strings.TrimPrefix(sortField, "+")
|
||||
}
|
||||
|
||||
// Handle DESC suffix
|
||||
if strings.HasSuffix(strings.ToLower(colName), " desc") {
|
||||
direction = "DESC"
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " desc")
|
||||
} else if strings.HasSuffix(strings.ToLower(colName), " asc") {
|
||||
colName = strings.TrimSuffix(strings.ToLower(colName), " asc")
|
||||
}
|
||||
|
||||
options.Sort = append(options.Sort, common.SortOption{
|
||||
Column: strings.TrimSpace(colName),
|
||||
Direction: direction,
|
||||
})
|
||||
}
|
||||
logger.Debug("X-Files: Added %d sort options", len(xfiles.Sort))
|
||||
}
|
||||
|
||||
// Filter fields
|
||||
if len(xfiles.FilterFields) > 0 {
|
||||
for _, filterField := range xfiles.FilterFields {
|
||||
options.Filters = append(options.Filters, common.FilterOption{
|
||||
Column: filterField.Field,
|
||||
Operator: filterField.Operator,
|
||||
Value: filterField.Value,
|
||||
LogicOperator: "AND", // Default to AND
|
||||
})
|
||||
}
|
||||
logger.Debug("X-Files: Added %d filter fields", len(xfiles.FilterFields))
|
||||
}
|
||||
|
||||
// SQL AND conditions -> CustomSQLWhere
|
||||
if len(xfiles.SqlAnd) > 0 {
|
||||
if options.CustomSQLWhere != "" {
|
||||
options.CustomSQLWhere += " AND "
|
||||
}
|
||||
options.CustomSQLWhere += "(" + strings.Join(xfiles.SqlAnd, " AND ") + ")"
|
||||
logger.Debug("X-Files: Added SQL AND conditions")
|
||||
}
|
||||
|
||||
// SQL OR conditions -> CustomSQLOr
|
||||
if len(xfiles.SqlOr) > 0 {
|
||||
if options.CustomSQLOr != "" {
|
||||
options.CustomSQLOr += " OR "
|
||||
}
|
||||
options.CustomSQLOr += "(" + strings.Join(xfiles.SqlOr, " OR ") + ")"
|
||||
logger.Debug("X-Files: Added SQL OR conditions")
|
||||
}
|
||||
|
||||
// Pagination - Limit
|
||||
if limitStr := xfiles.Limit.String(); limitStr != "" && limitStr != "0" {
|
||||
if limitVal, err := xfiles.Limit.Int64(); err == nil && limitVal > 0 {
|
||||
limit := int(limitVal)
|
||||
options.Limit = &limit
|
||||
logger.Debug("X-Files: Set limit: %d", limit)
|
||||
}
|
||||
}
|
||||
|
||||
// Pagination - Offset
|
||||
if offsetStr := xfiles.Offset.String(); offsetStr != "" && offsetStr != "0" {
|
||||
if offsetVal, err := xfiles.Offset.Int64(); err == nil && offsetVal > 0 {
|
||||
offset := int(offsetVal)
|
||||
options.Offset = &offset
|
||||
logger.Debug("X-Files: Set offset: %d", offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if xfiles.CursorForward != "" {
|
||||
options.CursorForward = xfiles.CursorForward
|
||||
logger.Debug("X-Files: Set cursor forward")
|
||||
}
|
||||
if xfiles.CursorBackward != "" {
|
||||
options.CursorBackward = xfiles.CursorBackward
|
||||
logger.Debug("X-Files: Set cursor backward")
|
||||
}
|
||||
|
||||
// Flags
|
||||
if xfiles.Skipcount {
|
||||
options.SkipCount = true
|
||||
logger.Debug("X-Files: Set skip count")
|
||||
}
|
||||
|
||||
// Process ParentTables and ChildTables recursively
|
||||
h.processXFilesRelations(&xfiles, options, "")
|
||||
}
|
||||
|
||||
// processXFilesRelations processes ParentTables and ChildTables from XFiles
|
||||
// and adds them as Preload options recursively
|
||||
func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||
if xfiles == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Process ParentTables
|
||||
if len(xfiles.ParentTables) > 0 {
|
||||
logger.Debug("X-Files: Processing %d parent tables", len(xfiles.ParentTables))
|
||||
for _, parentTable := range xfiles.ParentTables {
|
||||
h.addXFilesPreload(parentTable, options, basePath)
|
||||
}
|
||||
}
|
||||
|
||||
// Process ChildTables
|
||||
if len(xfiles.ChildTables) > 0 {
|
||||
logger.Debug("X-Files: Processing %d child tables", len(xfiles.ChildTables))
|
||||
for _, childTable := range xfiles.ChildTables {
|
||||
h.addXFilesPreload(childTable, options, basePath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRelationNamesInOptions resolves all table names to field names in preload options
|
||||
// This is called internally by parseOptionsFromHeaders when a model is provided
|
||||
func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) {
|
||||
if options == nil || model == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve relation names in all preload options
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
|
||||
// Split the relation path (e.g., "parent.child.grandchild")
|
||||
parts := strings.Split(preload.Relation, ".")
|
||||
resolvedParts := make([]string, 0, len(parts))
|
||||
|
||||
// Resolve each part of the path
|
||||
currentModel := model
|
||||
for _, part := range parts {
|
||||
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||
resolvedParts = append(resolvedParts, resolvedPart)
|
||||
|
||||
// Try to get the model type for the next level
|
||||
// This allows nested resolution
|
||||
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
|
||||
currentModel = nextModel
|
||||
}
|
||||
}
|
||||
|
||||
// Update the relation path with resolved names
|
||||
resolvedPath := strings.Join(resolvedParts, ".")
|
||||
if resolvedPath != preload.Relation {
|
||||
logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath)
|
||||
preload.Relation = resolvedPath
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve relation names in expand options
|
||||
for i := range options.Expand {
|
||||
expand := &options.Expand[i]
|
||||
resolved := h.resolveRelationName(model, expand.Relation)
|
||||
if resolved != expand.Relation {
|
||||
logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved)
|
||||
expand.Relation = resolved
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRelationName resolves a relation name or table name to the actual field name in the model
|
||||
// If the input is already a field name, it returns it as-is
|
||||
// If the input is a table name, it looks up the corresponding relation field
|
||||
func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string {
|
||||
if model == nil || nameOrTable == "" {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Check again after dereferencing
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return reflect.Invalid
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// Find the field by JSON tag or field name
|
||||
// First, check if the input matches a field name directly
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
if field.Name == nameOrTable {
|
||||
// It's already a field name
|
||||
logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||
return nameOrTable
|
||||
}
|
||||
}
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" {
|
||||
// Parse JSON tag (format: "name,omitempty")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if parts[0] == colName {
|
||||
return field.Type.Kind()
|
||||
// If not found as a field name, try to look it up as a table name
|
||||
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
// Check if it's a slice or pointer to a struct
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
|
||||
if targetType != nil {
|
||||
// Dereference pointer if the slice contains pointers
|
||||
if targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
|
||||
// Check if it's a struct type
|
||||
if targetType.Kind() == reflect.Struct {
|
||||
// Get the type name and normalize it
|
||||
typeName := targetType.Name()
|
||||
|
||||
// Extract the table name from type name
|
||||
// Patterns: ModelCoreMastertaskitem -> mastertaskitem
|
||||
// ModelMastertaskitem -> mastertaskitem
|
||||
normalizedTypeName := strings.ToLower(typeName)
|
||||
|
||||
// Remove common prefixes like "model", "modelcore", etc.
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||
|
||||
// Compare normalized names
|
||||
if normalizedTypeName == normalizedInput {
|
||||
logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName)
|
||||
return field.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check field name (case-insensitive)
|
||||
if strings.EqualFold(field.Name, colName) {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
// If no match found, return the original input
|
||||
logger.Debug("No field found for '%s', using as-is", nameOrTable)
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// Check snake_case conversion
|
||||
snakeCaseName := toSnakeCase(field.Name)
|
||||
if snakeCaseName == colName {
|
||||
return field.Type.Kind()
|
||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||
// and recursively processes its children
|
||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||
if xfile == nil || xfile.TableName == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Store the table name as-is for now - it will be resolved to field name later
|
||||
// when we have the model instance available
|
||||
relationPath := xfile.TableName
|
||||
if basePath != "" {
|
||||
relationPath = basePath + "." + xfile.TableName
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
||||
|
||||
// Create PreloadOption from XFiles configuration
|
||||
preloadOpt := common.PreloadOption{
|
||||
Relation: relationPath,
|
||||
Columns: xfile.Columns,
|
||||
OmitColumns: xfile.OmitColumns,
|
||||
}
|
||||
|
||||
// Add sorting if specified
|
||||
if len(xfile.Sort) > 0 {
|
||||
preloadOpt.Sort = make([]common.SortOption, 0, len(xfile.Sort))
|
||||
for _, sortField := range xfile.Sort {
|
||||
direction := "ASC"
|
||||
colName := sortField
|
||||
|
||||
// Handle direction prefixes
|
||||
if strings.HasPrefix(sortField, "-") {
|
||||
direction = "DESC"
|
||||
colName = strings.TrimPrefix(sortField, "-")
|
||||
} else if strings.HasPrefix(sortField, "+") {
|
||||
colName = strings.TrimPrefix(sortField, "+")
|
||||
}
|
||||
|
||||
preloadOpt.Sort = append(preloadOpt.Sort, common.SortOption{
|
||||
Column: strings.TrimSpace(colName),
|
||||
Direction: direction,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// toSnakeCase converts a string from CamelCase to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune('_')
|
||||
// Add filters if specified
|
||||
if len(xfile.FilterFields) > 0 {
|
||||
preloadOpt.Filters = make([]common.FilterOption, 0, len(xfile.FilterFields))
|
||||
for _, filterField := range xfile.FilterFields {
|
||||
preloadOpt.Filters = append(preloadOpt.Filters, common.FilterOption{
|
||||
Column: filterField.Field,
|
||||
Operator: filterField.Operator,
|
||||
Value: filterField.Value,
|
||||
LogicOperator: "AND",
|
||||
})
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// isNumericType checks if a reflect.Kind is a numeric type
|
||||
func isNumericType(kind reflect.Kind) bool {
|
||||
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||
}
|
||||
|
||||
// isStringType checks if a reflect.Kind is a string type
|
||||
func isStringType(kind reflect.Kind) bool {
|
||||
return kind == reflect.String
|
||||
}
|
||||
|
||||
// convertToNumericType converts a string value to the appropriate numeric type
|
||||
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
// Parse as integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Int8:
|
||||
bitSize = 8
|
||||
case reflect.Int16:
|
||||
bitSize = 16
|
||||
case reflect.Int32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
return int(intVal), nil
|
||||
case reflect.Int8:
|
||||
return int8(intVal), nil
|
||||
case reflect.Int16:
|
||||
return int16(intVal), nil
|
||||
case reflect.Int32:
|
||||
return int32(intVal), nil
|
||||
case reflect.Int64:
|
||||
return intVal, nil
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
// Parse as unsigned integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Uint8:
|
||||
bitSize = 8
|
||||
case reflect.Uint16:
|
||||
bitSize = 16
|
||||
case reflect.Uint32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Uint:
|
||||
return uint(uintVal), nil
|
||||
case reflect.Uint8:
|
||||
return uint8(uintVal), nil
|
||||
case reflect.Uint16:
|
||||
return uint16(uintVal), nil
|
||||
case reflect.Uint32:
|
||||
return uint32(uintVal), nil
|
||||
case reflect.Uint64:
|
||||
return uintVal, nil
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
// Parse as float
|
||||
bitSize := 64
|
||||
if kind == reflect.Float32 {
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||
}
|
||||
|
||||
if kind == reflect.Float32 {
|
||||
return float32(floatVal), nil
|
||||
}
|
||||
return floatVal, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||
}
|
||||
// Add WHERE clause if SQL conditions specified
|
||||
whereConditions := make([]string, 0)
|
||||
if len(xfile.SqlAnd) > 0 {
|
||||
whereConditions = append(whereConditions, xfile.SqlAnd...)
|
||||
}
|
||||
if len(whereConditions) > 0 {
|
||||
preloadOpt.Where = strings.Join(whereConditions, " AND ")
|
||||
}
|
||||
|
||||
// isNumericValue checks if a string value can be parsed as a number
|
||||
func isNumericValue(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
_, err := strconv.ParseFloat(value, 64)
|
||||
return err == nil
|
||||
// Add limit if specified
|
||||
if limitStr := xfile.Limit.String(); limitStr != "" && limitStr != "0" {
|
||||
if limitVal, err := xfile.Limit.Int64(); err == nil && limitVal > 0 {
|
||||
limit := int(limitVal)
|
||||
preloadOpt.Limit = &limit
|
||||
}
|
||||
}
|
||||
|
||||
// Add computed columns (CQL) -> ComputedQL
|
||||
if len(xfile.CQLColumns) > 0 {
|
||||
preloadOpt.ComputedQL = make(map[string]string)
|
||||
for i, cqlExpr := range xfile.CQLColumns {
|
||||
colName := fmt.Sprintf("cql%d", i+1)
|
||||
preloadOpt.ComputedQL[colName] = cqlExpr
|
||||
logger.Debug("X-Files: Added computed column %s to preload %s: %s", colName, relationPath, cqlExpr)
|
||||
}
|
||||
}
|
||||
|
||||
// Set recursive flag
|
||||
preloadOpt.Recursive = xfile.Recursive
|
||||
|
||||
// Extract relationship keys for proper foreign key filtering
|
||||
if xfile.PrimaryKey != "" {
|
||||
preloadOpt.PrimaryKey = xfile.PrimaryKey
|
||||
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
|
||||
}
|
||||
if xfile.RelatedKey != "" {
|
||||
preloadOpt.RelatedKey = xfile.RelatedKey
|
||||
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
|
||||
}
|
||||
if xfile.ForeignKey != "" {
|
||||
preloadOpt.ForeignKey = xfile.ForeignKey
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Add the preload option
|
||||
options.Preload = append(options.Preload, preloadOpt)
|
||||
|
||||
// Recursively process nested ParentTables and ChildTables
|
||||
if xfile.Recursive {
|
||||
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
}
|
||||
}
|
||||
|
||||
// ColumnCastInfo holds information about whether a column needs casting
|
||||
@@ -659,7 +964,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||
}
|
||||
|
||||
colType := h.getColumnTypeFromModel(model, filter.Column)
|
||||
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
|
||||
if colType == reflect.Invalid {
|
||||
// Column not found in model, no casting needed
|
||||
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
||||
@@ -670,18 +975,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
||||
valueIsNumeric := false
|
||||
if strVal, ok := filter.Value.(string); ok {
|
||||
strVal = strings.Trim(strVal, "%")
|
||||
valueIsNumeric = isNumericValue(strVal)
|
||||
valueIsNumeric = reflection.IsNumericValue(strVal)
|
||||
}
|
||||
|
||||
// Adjust based on column type
|
||||
switch {
|
||||
case isNumericType(colType):
|
||||
case reflection.IsNumericType(colType):
|
||||
// Column is numeric
|
||||
if valueIsNumeric {
|
||||
// Value is numeric - try to convert it
|
||||
if strVal, ok := filter.Value.(string); ok {
|
||||
strVal = strings.Trim(strVal, "%")
|
||||
numericVal, err := convertToNumericType(strVal, colType)
|
||||
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||
@@ -696,7 +1001,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||
}
|
||||
|
||||
case isStringType(colType):
|
||||
case reflection.IsStringType(colType):
|
||||
// String columns don't need casting
|
||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||
|
||||
|
||||
403
pkg/restheadspec/query_params_test.go
Normal file
403
pkg/restheadspec/query_params_test.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockRequest implements common.Request interface for testing
|
||||
type MockRequest struct {
|
||||
headers map[string]string
|
||||
queryParams map[string]string
|
||||
}
|
||||
|
||||
func (m *MockRequest) Method() string {
|
||||
return "GET"
|
||||
}
|
||||
|
||||
func (m *MockRequest) URL() string {
|
||||
return "http://example.com/test"
|
||||
}
|
||||
|
||||
func (m *MockRequest) Header(key string) string {
|
||||
return m.headers[key]
|
||||
}
|
||||
|
||||
func (m *MockRequest) AllHeaders() map[string]string {
|
||||
return m.headers
|
||||
}
|
||||
|
||||
func (m *MockRequest) Body() ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *MockRequest) PathParam(key string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (m *MockRequest) QueryParam(key string) string {
|
||||
return m.queryParams[key]
|
||||
}
|
||||
|
||||
func (m *MockRequest) AllQueryParams() map[string]string {
|
||||
return m.queryParams
|
||||
}
|
||||
|
||||
func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams map[string]string
|
||||
headers map[string]string
|
||||
validate func(t *testing.T, options ExtendedRequestOptions)
|
||||
}{
|
||||
{
|
||||
name: "Parse custom SQL WHERE from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if options.CustomSQLWhere == "" {
|
||||
t.Error("Expected CustomSQLWhere to be set from query param")
|
||||
}
|
||||
expected := `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null)`
|
||||
if options.CustomSQLWhere != expected {
|
||||
t.Errorf("Expected CustomSQLWhere=%q, got %q", expected, options.CustomSQLWhere)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse sort from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-sort": "-applicationdate,name",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.Sort) != 2 {
|
||||
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||
return
|
||||
}
|
||||
if options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||
t.Errorf("Expected first sort: applicationdate DESC, got %s %s", options.Sort[0].Column, options.Sort[0].Direction)
|
||||
}
|
||||
if options.Sort[1].Column != "name" || options.Sort[1].Direction != "ASC" {
|
||||
t.Errorf("Expected second sort: name ASC, got %s %s", options.Sort[1].Column, options.Sort[1].Direction)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse limit and offset from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-limit": "100",
|
||||
"x-offset": "50",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if options.Limit == nil || *options.Limit != 100 {
|
||||
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||
}
|
||||
if options.Offset == nil || *options.Offset != 50 {
|
||||
t.Errorf("Expected offset=50, got %v", options.Offset)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse field filters from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-fieldfilter-status": "active",
|
||||
"x-fieldfilter-type": "user",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.Filters) != 2 {
|
||||
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||
return
|
||||
}
|
||||
// Check that filters were created
|
||||
foundStatus := false
|
||||
foundType := false
|
||||
for _, filter := range options.Filters {
|
||||
if filter.Column == "status" && filter.Value == "active" && filter.Operator == "eq" {
|
||||
foundStatus = true
|
||||
}
|
||||
if filter.Column == "type" && filter.Value == "user" && filter.Operator == "eq" {
|
||||
foundType = true
|
||||
}
|
||||
}
|
||||
if !foundStatus {
|
||||
t.Error("Expected status filter not found")
|
||||
}
|
||||
if !foundType {
|
||||
t.Error("Expected type filter not found")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse select fields from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-select-fields": "id,name,email",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.Columns) != 3 {
|
||||
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||
return
|
||||
}
|
||||
expected := []string{"id", "name", "email"}
|
||||
for i, col := range expected {
|
||||
if i >= len(options.Columns) || options.Columns[i] != col {
|
||||
t.Errorf("Expected column[%d]=%s, got %v", i, col, options.Columns)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse preload from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-preload": "posts:title,content|comments",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.Preload) != 2 {
|
||||
t.Errorf("Expected 2 preload options, got %d", len(options.Preload))
|
||||
return
|
||||
}
|
||||
// Check first preload (posts with columns)
|
||||
if options.Preload[0].Relation != "posts" {
|
||||
t.Errorf("Expected first preload relation=posts, got %s", options.Preload[0].Relation)
|
||||
}
|
||||
if len(options.Preload[0].Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns for posts preload, got %d", len(options.Preload[0].Columns))
|
||||
}
|
||||
// Check second preload (comments without columns)
|
||||
if options.Preload[1].Relation != "comments" {
|
||||
t.Errorf("Expected second preload relation=comments, got %s", options.Preload[1].Relation)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Query params take precedence over headers",
|
||||
queryParams: map[string]string{
|
||||
"x-limit": "100",
|
||||
},
|
||||
headers: map[string]string{
|
||||
"X-Limit": "50",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if options.Limit == nil || *options.Limit != 100 {
|
||||
t.Errorf("Expected query param limit=100 to override header, got %v", options.Limit)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse search operators from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-searchop-contains-name": "john",
|
||||
"x-searchop-gt-age": "18",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.Filters) != 2 {
|
||||
t.Errorf("Expected 2 filters, got %d", len(options.Filters))
|
||||
return
|
||||
}
|
||||
// Check for ILIKE filter
|
||||
foundContains := false
|
||||
foundGt := false
|
||||
for _, filter := range options.Filters {
|
||||
if filter.Column == "name" && filter.Operator == "ilike" {
|
||||
foundContains = true
|
||||
}
|
||||
if filter.Column == "age" && filter.Operator == "gt" && filter.Value == "18" {
|
||||
foundGt = true
|
||||
}
|
||||
}
|
||||
if !foundContains {
|
||||
t.Error("Expected contains filter not found")
|
||||
}
|
||||
if !foundGt {
|
||||
t.Error("Expected gt filter not found")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse complex example with multiple params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0)`,
|
||||
"x-sort": "-applicationdate",
|
||||
"x-limit": "100",
|
||||
"x-select-fields": "id,name,status",
|
||||
"x-fieldfilter-active": "true",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
// Validate CustomSQLWhere
|
||||
if options.CustomSQLWhere == "" {
|
||||
t.Error("Expected CustomSQLWhere to be set")
|
||||
}
|
||||
// Validate Sort
|
||||
if len(options.Sort) != 1 || options.Sort[0].Column != "applicationdate" || options.Sort[0].Direction != "DESC" {
|
||||
t.Errorf("Expected sort by applicationdate DESC, got %v", options.Sort)
|
||||
}
|
||||
// Validate Limit
|
||||
if options.Limit == nil || *options.Limit != 100 {
|
||||
t.Errorf("Expected limit=100, got %v", options.Limit)
|
||||
}
|
||||
// Validate Columns
|
||||
if len(options.Columns) != 3 {
|
||||
t.Errorf("Expected 3 columns, got %d", len(options.Columns))
|
||||
}
|
||||
// Validate Filters
|
||||
if len(options.Filters) < 1 {
|
||||
t.Error("Expected at least 1 filter")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse distinct flag from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-distinct": "true",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if !options.Distinct {
|
||||
t.Error("Expected Distinct to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse skip count flag from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-skipcount": "true",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if !options.SkipCount {
|
||||
t.Error("Expected SkipCount to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse response format from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-syncfusion": "true",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if options.ResponseFormat != "syncfusion" {
|
||||
t.Errorf("Expected ResponseFormat=syncfusion, got %s", options.ResponseFormat)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL OR from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-or": `("field1" = 'value1' OR "field2" = 'value2')`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if options.CustomSQLOr == "" {
|
||||
t.Error("Expected CustomSQLOr to be set")
|
||||
}
|
||||
expected := `("field1" = 'value1' OR "field2" = 'value2')`
|
||||
if options.CustomSQLOr != expected {
|
||||
t.Errorf("Expected CustomSQLOr=%q, got %q", expected, options.CustomSQLOr)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create mock request
|
||||
req := &MockRequest{
|
||||
headers: tt.headers,
|
||||
queryParams: tt.queryParams,
|
||||
}
|
||||
if req.headers == nil {
|
||||
req.headers = make(map[string]string)
|
||||
}
|
||||
if req.queryParams == nil {
|
||||
req.queryParams = make(map[string]string)
|
||||
}
|
||||
|
||||
// Parse options
|
||||
options := handler.parseOptionsFromHeaders(req, nil)
|
||||
|
||||
// Validate
|
||||
tt.validate(t, options)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryParamsWithURLEncoding(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
// Test with URL-encoded query parameter (like the user's example)
|
||||
req := &MockRequest{
|
||||
headers: make(map[string]string),
|
||||
queryParams: map[string]string{
|
||||
// URL-encoded version of the SQL WHERE clause
|
||||
"x-custom-sql-w-1": `("v_webui_clients".clientstatus = 0 or "v_webui_clients".clientstatus is null) and ("v_webui_clients".inactive = 0 or "v_webui_clients".inactive is null)`,
|
||||
},
|
||||
}
|
||||
|
||||
options := handler.parseOptionsFromHeaders(req, nil)
|
||||
|
||||
if options.CustomSQLWhere == "" {
|
||||
t.Error("Expected CustomSQLWhere to be set from URL-encoded query param")
|
||||
}
|
||||
|
||||
// The SQL should contain the expected conditions
|
||||
if !contains(options.CustomSQLWhere, "clientstatus") {
|
||||
t.Error("Expected CustomSQLWhere to contain 'clientstatus'")
|
||||
}
|
||||
if !contains(options.CustomSQLWhere, "inactive") {
|
||||
t.Error("Expected CustomSQLWhere to contain 'inactive'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
// Test that headers and query params can work together
|
||||
req := &MockRequest{
|
||||
headers: map[string]string{
|
||||
"X-Select-Fields": "id,name",
|
||||
"X-Limit": "50",
|
||||
},
|
||||
queryParams: map[string]string{
|
||||
"x-sort": "-created_at",
|
||||
"x-offset": "10",
|
||||
// This should override the header value
|
||||
"x-limit": "100",
|
||||
},
|
||||
}
|
||||
|
||||
options := handler.parseOptionsFromHeaders(req, nil)
|
||||
|
||||
// Verify columns from header
|
||||
if len(options.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns from header, got %d", len(options.Columns))
|
||||
}
|
||||
|
||||
// Verify sort from query param
|
||||
if len(options.Sort) != 1 || options.Sort[0].Column != "created_at" {
|
||||
t.Errorf("Expected sort from query param, got %v", options.Sort)
|
||||
}
|
||||
|
||||
// Verify offset from query param
|
||||
if options.Offset == nil || *options.Offset != 10 {
|
||||
t.Errorf("Expected offset=10 from query param, got %v", options.Offset)
|
||||
}
|
||||
|
||||
// Verify limit from query param (should override header)
|
||||
if options.Limit == nil {
|
||||
t.Error("Expected limit to be set from query param")
|
||||
} else if *options.Limit != 100 {
|
||||
t.Errorf("Expected limit=100 from query param (overriding header), got %d", *options.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
431
pkg/restheadspec/xfiles.go
Normal file
431
pkg/restheadspec/xfiles.go
Normal file
@@ -0,0 +1,431 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type XFiles struct {
|
||||
TableName string `json:"tablename"`
|
||||
Schema string `json:"schema"`
|
||||
PrimaryKey string `json:"primarykey"`
|
||||
ForeignKey string `json:"foreignkey"`
|
||||
RelatedKey string `json:"relatedkey"`
|
||||
Sort []string `json:"sort"`
|
||||
Prefix string `json:"prefix"`
|
||||
Editable bool `json:"editable"`
|
||||
Recursive bool `json:"recursive"`
|
||||
Expand bool `json:"expand"`
|
||||
Rownumber bool `json:"rownumber"`
|
||||
Skipcount bool `json:"skipcount"`
|
||||
Offset json.Number `json:"offset"`
|
||||
Limit json.Number `json:"limit"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
CQLColumns []string `json:"cql_columns"`
|
||||
|
||||
SqlJoins []string `json:"sql_joins"`
|
||||
SqlOr []string `json:"sql_or"`
|
||||
SqlAnd []string `json:"sql_and"`
|
||||
ParentTables []*XFiles `json:"parenttables"`
|
||||
ChildTables []*XFiles `json:"childtables"`
|
||||
ModelType reflect.Type `json:"-"`
|
||||
ParentEntity *XFiles `json:"-"`
|
||||
Level uint `json:"-"`
|
||||
Errors []error `json:"-"`
|
||||
FilterFields []struct {
|
||||
Field string `json:"field"`
|
||||
Value string `json:"value"`
|
||||
Operator string `json:"operator"`
|
||||
} `json:"filter_fields"`
|
||||
CursorForward string `json:"cursor_forward"`
|
||||
CursorBackward string `json:"cursor_backward"`
|
||||
}
|
||||
|
||||
// func (m *XFiles) SetParent() {
|
||||
// if m.ChildTables != nil {
|
||||
// for _, child := range m.ChildTables {
|
||||
// if child.ParentEntity != nil {
|
||||
// continue
|
||||
// }
|
||||
// child.ParentEntity = m
|
||||
// child.Level = m.Level + 1000
|
||||
// child.SetParent()
|
||||
// }
|
||||
// }
|
||||
// if m.ParentTables != nil {
|
||||
// for _, pt := range m.ParentTables {
|
||||
// if pt.ParentEntity != nil {
|
||||
// continue
|
||||
// }
|
||||
// pt.ParentEntity = m
|
||||
// pt.Level = m.Level + 1
|
||||
// pt.SetParent()
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (m *XFiles) GetParentRelations() []reflection.GormRelationType {
|
||||
// if m.ParentEntity == nil {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// foundRelations := make(GormRelationTypeList, 0)
|
||||
// rels := reflection.GetValidModelRelationTypes(m.ParentEntity.ModelType, false)
|
||||
|
||||
// if m.ParentEntity.ModelType == nil {
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// for _, rel := range rels {
|
||||
// // if len(foundRelations) > 0 {
|
||||
// // break
|
||||
// // }
|
||||
// if rel.FieldName != "" && rel.AssociationTable.Name() == m.ModelType.Name() {
|
||||
|
||||
// if rel.AssociationKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.AssociationKey, m.RelatedKey) {
|
||||
// foundRelations = append(foundRelations, rel)
|
||||
// } else if rel.AssociationKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.AssociationKey, m.ForeignKey) {
|
||||
// foundRelations = append(foundRelations, rel)
|
||||
// } else if rel.ForeignKey != "" && m.ForeignKey != "" && strings.EqualFold(rel.ForeignKey, m.ForeignKey) {
|
||||
// foundRelations = append(foundRelations, rel)
|
||||
// } else if rel.ForeignKey != "" && m.RelatedKey != "" && strings.EqualFold(rel.ForeignKey, m.RelatedKey) {
|
||||
// foundRelations = append(foundRelations, rel)
|
||||
// } else if rel.ForeignKey != "" && m.ForeignKey == "" && m.RelatedKey == "" {
|
||||
// foundRelations = append(foundRelations, rel)
|
||||
// }
|
||||
// }
|
||||
|
||||
// //idName := fmt.Sprintf("%s_to_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
|
||||
// }
|
||||
|
||||
// sort.Sort(foundRelations)
|
||||
// finalList := make(GormRelationTypeList, 0)
|
||||
// dups := make(map[string]bool)
|
||||
// for _, rel := range foundRelations {
|
||||
// idName := fmt.Sprintf("%s_to_%s_%s_%s=%s_m%v", rel.TableName, rel.AssociationTableName, rel.FieldName, rel.ForeignKey, rel.AssociationKey, rel.OneToMany)
|
||||
// if dups[idName] {
|
||||
// continue
|
||||
// }
|
||||
// finalList = append(finalList, rel)
|
||||
// dups[idName] = true
|
||||
// }
|
||||
|
||||
// //fmt.Printf("GetParentRelations %s: %+v %d=%d\n", m.TableName, dups, len(finalList), len(foundRelations))
|
||||
|
||||
// return finalList
|
||||
// }
|
||||
|
||||
// func (m *XFiles) GetUpdatableTableNames() []string {
|
||||
// foundTables := make([]string, 0)
|
||||
// if m.Editable {
|
||||
// foundTables = append(foundTables, m.TableName)
|
||||
// }
|
||||
// if m.ParentTables != nil {
|
||||
// for _, pt := range m.ParentTables {
|
||||
// list := pt.GetUpdatableTableNames()
|
||||
// if list != nil {
|
||||
// foundTables = append(foundTables, list...)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if m.ChildTables != nil {
|
||||
// for _, ct := range m.ChildTables {
|
||||
// list := ct.GetUpdatableTableNames()
|
||||
// if list != nil {
|
||||
// foundTables = append(foundTables, list...)
|
||||
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return foundTables
|
||||
// }
|
||||
|
||||
// func (m *XFiles) preload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
|
||||
|
||||
// path := pPath
|
||||
// _, colval := JSONSyntaxToSQLIn(path, m.ModelType, "preload")
|
||||
// if colval != "" {
|
||||
// path = colval
|
||||
// }
|
||||
|
||||
// if path == "" {
|
||||
// return db, fmt.Errorf("invalid preload path %s", path)
|
||||
// }
|
||||
|
||||
// sortList := ""
|
||||
// if m.Sort != nil {
|
||||
// for _, sort := range m.Sort {
|
||||
// descSort := false
|
||||
// if strings.HasPrefix(sort, "-") || strings.Contains(strings.ToLower(sort), " desc") {
|
||||
// descSort = true
|
||||
// }
|
||||
// sort = strings.TrimPrefix(strings.TrimPrefix(sort, "+"), "-")
|
||||
// sort = strings.ReplaceAll(strings.ReplaceAll(sort, " desc", ""), " asc", "")
|
||||
// if descSort {
|
||||
// sort = sort + " desc"
|
||||
// }
|
||||
// sortList = sort
|
||||
// }
|
||||
// }
|
||||
|
||||
// SrcColumns := reflection.GetModelSQLColumns(m.ModelType)
|
||||
// Columns := make([]string, 0)
|
||||
|
||||
// for _, s := range SrcColumns {
|
||||
// for _, v := range m.Columns {
|
||||
// if strings.EqualFold(v, s) {
|
||||
// Columns = append(Columns, v)
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if len(Columns) == 0 {
|
||||
// Columns = SrcColumns
|
||||
// }
|
||||
|
||||
// chain := db
|
||||
|
||||
// // //Do expand where we can
|
||||
// // if m.Expand {
|
||||
// // ops := func(subchain *gorm.DB) *gorm.DB {
|
||||
// // subchain = subchain.Select(strings.Join(m.Columns, ","))
|
||||
|
||||
// // if m.Filter != "" {
|
||||
// // subchain = subchain.Where(m.Filter)
|
||||
// // }
|
||||
// // return subchain
|
||||
// // }
|
||||
// // chain = chain.Joins(path, ops(chain))
|
||||
// // }
|
||||
|
||||
// //fmt.Printf("Preloading %s: %s lvl:%d \n", m.TableName, path, m.Level)
|
||||
// //Do preload
|
||||
// chain = chain.Preload(path, func(db *gorm.DB) *gorm.DB {
|
||||
// subchain := db
|
||||
|
||||
// if sortList != "" {
|
||||
// subchain = subchain.Order(sortList)
|
||||
// }
|
||||
|
||||
// for _, sql := range m.SqlAnd {
|
||||
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
|
||||
// if fnType == 0 {
|
||||
// colval = ValidSQL(colval, "select")
|
||||
// }
|
||||
// subchain = subchain.Where(colval)
|
||||
// }
|
||||
|
||||
// for _, sql := range m.SqlOr {
|
||||
// fnType, colval := JSONSyntaxToSQL(sql, m.ModelType)
|
||||
// if fnType == 0 {
|
||||
// colval = ValidSQL(colval, "select")
|
||||
// }
|
||||
// subchain = subchain.Or(colval)
|
||||
// }
|
||||
|
||||
// limitval, err := m.Limit.Int64()
|
||||
// if err == nil && limitval > 0 {
|
||||
// subchain = subchain.Limit(int(limitval))
|
||||
// }
|
||||
|
||||
// for _, j := range m.SqlJoins {
|
||||
// subchain = subchain.Joins(ValidSQL(j, "select"))
|
||||
// }
|
||||
|
||||
// offsetval, err := m.Offset.Int64()
|
||||
// if err == nil && offsetval > 0 {
|
||||
// subchain = subchain.Offset(int(offsetval))
|
||||
// }
|
||||
|
||||
// cols := make([]string, 0)
|
||||
|
||||
// for _, col := range Columns {
|
||||
// canAdd := true
|
||||
// for _, omit := range m.OmitColumns {
|
||||
// if col == omit {
|
||||
// canAdd = false
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// if canAdd {
|
||||
// cols = append(cols, col)
|
||||
// }
|
||||
// }
|
||||
|
||||
// for i, col := range m.CQLColumns {
|
||||
// cols = append(cols, fmt.Sprintf("(%s) as cql%d", col, i+1))
|
||||
// }
|
||||
|
||||
// if len(cols) > 0 {
|
||||
|
||||
// colStr := strings.Join(cols, ",")
|
||||
// subchain = subchain.Select(colStr)
|
||||
// }
|
||||
|
||||
// if m.Recursive && pCnt < 5 {
|
||||
// paths := strings.Split(path, ".")
|
||||
|
||||
// p := paths[0]
|
||||
// if len(paths) > 1 {
|
||||
// p = strings.Join(paths[1:], ".")
|
||||
// }
|
||||
// for i := uint(0); i < 3; i++ {
|
||||
// inlineStr := strings.Repeat(p+".", int(i+1))
|
||||
// inlineStr = strings.TrimRight(inlineStr, ".")
|
||||
|
||||
// fmt.Printf("Preloading Recursive (%d) %s: %s lvl:%d \n", i, m.TableName, inlineStr, m.Level)
|
||||
// subchain, err = m.preload(subchain, inlineStr, pCnt+i)
|
||||
// if err != nil {
|
||||
// cfg.LogError("Preload (%s,%d) error: %v", m.TableName, pCnt, err)
|
||||
// } else {
|
||||
|
||||
// if m.ChildTables != nil {
|
||||
// for _, child := range m.ChildTables {
|
||||
// if child.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// subchain, _ = child.ChainPreload(subchain, inlineStr, pCnt+i)
|
||||
|
||||
// }
|
||||
// }
|
||||
// if m.ParentTables != nil {
|
||||
// for _, pt := range m.ParentTables {
|
||||
// if pt.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// subchain, _ = pt.ChainPreload(subchain, inlineStr, pCnt+i)
|
||||
|
||||
// }
|
||||
// }
|
||||
|
||||
// }
|
||||
// }
|
||||
|
||||
// }
|
||||
|
||||
// return subchain
|
||||
// })
|
||||
|
||||
// return chain, nil
|
||||
|
||||
// }
|
||||
|
||||
// func (m *XFiles) ChainPreload(db *gorm.DB, pPath string, pCnt uint) (*gorm.DB, error) {
|
||||
// var err error
|
||||
// chain := db
|
||||
|
||||
// relations := m.GetParentRelations()
|
||||
// if pCnt > 10000 {
|
||||
// cfg.LogError("Preload Max size (%s,%s): %v", m.TableName, pPath, err)
|
||||
// return chain, nil
|
||||
// }
|
||||
|
||||
// hasPreloadError := false
|
||||
// for _, rel := range relations {
|
||||
// path := rel.FieldName
|
||||
// if pPath != "" {
|
||||
// path = fmt.Sprintf("%s.%s", pPath, rel.FieldName)
|
||||
// }
|
||||
|
||||
// chain, err = m.preload(chain, path, pCnt)
|
||||
// if err != nil {
|
||||
// cfg.LogError("Preload Error (%s,%s): %v", m.TableName, path, err)
|
||||
// hasPreloadError = true
|
||||
// //return chain, err
|
||||
// }
|
||||
|
||||
// //fmt.Printf("Preloading Rel %v: %s @ %s lvl:%d \n", m.Recursive, path, m.TableName, m.Level)
|
||||
// if !hasPreloadError && m.ChildTables != nil {
|
||||
// for _, child := range m.ChildTables {
|
||||
// if child.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// chain, err = child.ChainPreload(chain, path, pCnt)
|
||||
// if err != nil {
|
||||
// return chain, err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if !hasPreloadError && m.ParentTables != nil {
|
||||
// for _, pt := range m.ParentTables {
|
||||
// if pt.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// chain, err = pt.ChainPreload(chain, path, pCnt)
|
||||
// if err != nil {
|
||||
// return chain, err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// if len(relations) == 0 {
|
||||
// if m.ChildTables != nil {
|
||||
// for _, child := range m.ChildTables {
|
||||
// if child.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// chain, err = child.ChainPreload(chain, pPath, pCnt)
|
||||
// if err != nil {
|
||||
// return chain, err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// if m.ParentTables != nil {
|
||||
// for _, pt := range m.ParentTables {
|
||||
// if pt.ParentEntity == nil {
|
||||
// continue
|
||||
// }
|
||||
// chain, err = pt.ChainPreload(chain, pPath, pCnt)
|
||||
// if err != nil {
|
||||
// return chain, err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return chain, nil
|
||||
// }
|
||||
|
||||
// func (m *XFiles) Fill() {
|
||||
// m.ModelType = models.GetModelType(m.Schema, m.TableName)
|
||||
|
||||
// if m.ModelType == nil {
|
||||
// m.Errors = append(m.Errors, fmt.Errorf("ModelType not found for %s", m.TableName))
|
||||
// }
|
||||
// if m.Prefix == "" {
|
||||
// m.Prefix = reflection.GetTablePrefixFromType(m.ModelType)
|
||||
// }
|
||||
// if m.PrimaryKey == "" {
|
||||
// m.PrimaryKey = reflection.GetPKNameFromType(m.ModelType)
|
||||
// }
|
||||
|
||||
// if m.Schema == "" {
|
||||
// m.Schema = reflection.GetSchemaNameFromType(m.ModelType)
|
||||
// }
|
||||
|
||||
// for _, t := range m.ParentTables {
|
||||
// t.Fill()
|
||||
// }
|
||||
|
||||
// for _, t := range m.ChildTables {
|
||||
// t.Fill()
|
||||
// }
|
||||
// }
|
||||
|
||||
// type GormRelationTypeList []reflection.GormRelationType
|
||||
|
||||
// func (s GormRelationTypeList) Len() int { return len(s) }
|
||||
// func (s GormRelationTypeList) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// func (s GormRelationTypeList) Less(i, j int) bool {
|
||||
// if strings.HasPrefix(strings.ToLower(s[j].FieldName),
|
||||
// strings.ToLower(fmt.Sprintf("%s_%s_%s", s[i].AssociationSchema, s[i].AssociationTable, s[i].AssociationKey))) {
|
||||
// return true
|
||||
// }
|
||||
|
||||
// return s[i].FieldName < s[j].FieldName
|
||||
// }
|
||||
213
pkg/restheadspec/xfiles_example.md
Normal file
213
pkg/restheadspec/xfiles_example.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# X-Files Header Usage
|
||||
|
||||
The `x-files` header allows you to configure complex query options using a single JSON object. The XFiles configuration is parsed and populates the `ExtendedRequestOptions` fields, which means it integrates seamlessly with the existing query building system.
|
||||
|
||||
## Architecture
|
||||
|
||||
When an `x-files` header is received:
|
||||
1. It's parsed into an `XFiles` struct
|
||||
2. The `XFiles` fields populate the `ExtendedRequestOptions` (columns, filters, sort, preload, etc.)
|
||||
3. The normal query building process applies these options to the SQL query
|
||||
4. This allows x-files to work alongside individual headers if needed
|
||||
|
||||
## Basic Example
|
||||
|
||||
```http
|
||||
GET /public/users
|
||||
X-Files: {"tablename":"users","columns":["id","name","email"],"limit":"10","offset":"0"}
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```http
|
||||
GET /public/users
|
||||
X-Files: {
|
||||
"tablename": "users",
|
||||
"schema": "public",
|
||||
"columns": ["id", "name", "email", "created_at"],
|
||||
"omit_columns": [],
|
||||
"sort": ["-created_at", "name"],
|
||||
"limit": "50",
|
||||
"offset": "0",
|
||||
"filter_fields": [
|
||||
{
|
||||
"field": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"field": "age",
|
||||
"operator": "gt",
|
||||
"value": "18"
|
||||
}
|
||||
],
|
||||
"sql_and": ["deleted_at IS NULL"],
|
||||
"sql_or": [],
|
||||
"cql_columns": ["UPPER(name)"],
|
||||
"skipcount": false,
|
||||
"distinct": false
|
||||
}
|
||||
```
|
||||
|
||||
## Supported Filter Operators
|
||||
|
||||
- `eq` - equals
|
||||
- `neq` - not equals
|
||||
- `gt` - greater than
|
||||
- `gte` - greater than or equals
|
||||
- `lt` - less than
|
||||
- `lte` - less than or equals
|
||||
- `like` - SQL LIKE
|
||||
- `ilike` - case-insensitive LIKE
|
||||
- `in` - IN clause
|
||||
- `between` - between (exclusive)
|
||||
- `between_inclusive` - between (inclusive)
|
||||
- `is_null` - is NULL
|
||||
- `is_not_null` - is NOT NULL
|
||||
|
||||
## Sorting
|
||||
|
||||
Sort fields can be prefixed with:
|
||||
- `+` for ascending (default)
|
||||
- `-` for descending
|
||||
|
||||
Examples:
|
||||
- `"sort": ["name"]` - ascending by name
|
||||
- `"sort": ["-created_at"]` - descending by created_at
|
||||
- `"sort": ["-created_at", "name"]` - multiple sorts
|
||||
|
||||
## Computed Columns (CQL)
|
||||
|
||||
Use `cql_columns` to add computed SQL expressions:
|
||||
|
||||
```json
|
||||
{
|
||||
"cql_columns": [
|
||||
"UPPER(name)",
|
||||
"CONCAT(first_name, ' ', last_name)"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
These will be available as `cql1`, `cql2`, etc. in the response.
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
```json
|
||||
{
|
||||
"cursor_forward": "eyJpZCI6MTAwfQ==",
|
||||
"cursor_backward": ""
|
||||
}
|
||||
```
|
||||
|
||||
## Base64 Encoding
|
||||
|
||||
For complex JSON, you can base64-encode the value and prefix it with `ZIP_` or `__`:
|
||||
|
||||
```http
|
||||
GET /public/users
|
||||
X-Files: ZIP_eyJ0YWJsZW5hbWUiOiJ1c2VycyIsImxpbWl0IjoiMTAifQ==
|
||||
```
|
||||
|
||||
## XFiles Struct Reference
|
||||
|
||||
```go
|
||||
type XFiles struct {
|
||||
TableName string `json:"tablename"`
|
||||
Schema string `json:"schema"`
|
||||
PrimaryKey string `json:"primarykey"`
|
||||
ForeignKey string `json:"foreignkey"`
|
||||
RelatedKey string `json:"relatedkey"`
|
||||
Sort []string `json:"sort"`
|
||||
Prefix string `json:"prefix"`
|
||||
Editable bool `json:"editable"`
|
||||
Recursive bool `json:"recursive"`
|
||||
Expand bool `json:"expand"`
|
||||
Rownumber bool `json:"rownumber"`
|
||||
Skipcount bool `json:"skipcount"`
|
||||
Offset json.Number `json:"offset"`
|
||||
Limit json.Number `json:"limit"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
CQLColumns []string `json:"cql_columns"`
|
||||
SqlJoins []string `json:"sql_joins"`
|
||||
SqlOr []string `json:"sql_or"`
|
||||
SqlAnd []string `json:"sql_and"`
|
||||
FilterFields []struct {
|
||||
Field string `json:"field"`
|
||||
Value string `json:"value"`
|
||||
Operator string `json:"operator"`
|
||||
} `json:"filter_fields"`
|
||||
CursorForward string `json:"cursor_forward"`
|
||||
CursorBackward string `json:"cursor_backward"`
|
||||
}
|
||||
```
|
||||
|
||||
## Recursive Preloading with ParentTables and ChildTables
|
||||
|
||||
XFiles now supports recursive preloading of related entities:
|
||||
|
||||
```json
|
||||
{
|
||||
"tablename": "users",
|
||||
"columns": ["id", "name"],
|
||||
"limit": "10",
|
||||
"parenttables": [
|
||||
{
|
||||
"tablename": "Company",
|
||||
"columns": ["id", "name", "industry"],
|
||||
"sort": ["-created_at"]
|
||||
}
|
||||
],
|
||||
"childtables": [
|
||||
{
|
||||
"tablename": "Orders",
|
||||
"columns": ["id", "total", "status"],
|
||||
"limit": "5",
|
||||
"sort": ["-order_date"],
|
||||
"filter_fields": [
|
||||
{"field": "status", "operator": "eq", "value": "completed"}
|
||||
],
|
||||
"childtables": [
|
||||
{
|
||||
"tablename": "OrderItems",
|
||||
"columns": ["id", "product_name", "quantity"],
|
||||
"recursive": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### How Recursive Preloading Works
|
||||
|
||||
- **ParentTables**: Preloads parent relationships (e.g., User -> Company)
|
||||
- **ChildTables**: Preloads child relationships (e.g., User -> Orders -> OrderItems)
|
||||
- **Recursive**: When `true`, continues preloading the same relation recursively
|
||||
- Each nested table can have its own:
|
||||
- Column selection (`columns`, `omit_columns`)
|
||||
- Filtering (`filter_fields`, `sql_and`)
|
||||
- Sorting (`sort`)
|
||||
- Pagination (`limit`)
|
||||
- Further nesting (`parenttables`, `childtables`)
|
||||
|
||||
### Relation Path Building
|
||||
|
||||
Relations are built as dot-separated paths:
|
||||
- `Company` (direct parent)
|
||||
- `Orders` (direct child)
|
||||
- `Orders.OrderItems` (nested child)
|
||||
- `Orders.OrderItems.Product` (deeply nested)
|
||||
|
||||
## Notes
|
||||
|
||||
- Individual headers (like `x-select-fields`, `x-sort`, etc.) can still be used alongside `x-files`
|
||||
- X-Files populates `ExtendedRequestOptions` which is then processed by the normal query building logic
|
||||
- ParentTables and ChildTables are converted to `PreloadOption` entries with full support for:
|
||||
- Column selection
|
||||
- Filtering
|
||||
- Sorting
|
||||
- Limit
|
||||
- Recursive nesting
|
||||
- The relation name in ParentTables/ChildTables should match the GORM/Bun relation field name on the model
|
||||
@@ -372,7 +372,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Create department should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Create department should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got the created data back
|
||||
assert.NotEmpty(t, result, "Create department should return data")
|
||||
assert.Equal(t, deptID, result["id"], "Created department should have correct ID")
|
||||
}
|
||||
logger.Info("Department created successfully: %s", deptID)
|
||||
})
|
||||
|
||||
@@ -393,7 +400,14 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Create employee should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Create employee should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got the created data back
|
||||
assert.NotEmpty(t, result, "Create employee should return data")
|
||||
assert.Equal(t, empID, result["id"], "Created employee should have correct ID")
|
||||
}
|
||||
logger.Info("Employee created successfully: %s", empID)
|
||||
})
|
||||
|
||||
@@ -540,7 +554,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Update department should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Update department should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got the updated data back
|
||||
assert.NotEmpty(t, result, "Update department should return data")
|
||||
}
|
||||
logger.Info("Department updated successfully: %s", deptID)
|
||||
|
||||
// Verify update by reading the department again
|
||||
@@ -558,7 +578,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Update employee should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Update employee should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got the updated data back
|
||||
assert.NotEmpty(t, result, "Update employee should return data")
|
||||
}
|
||||
logger.Info("Employee updated successfully: %s", empID)
|
||||
})
|
||||
|
||||
@@ -569,7 +595,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Delete employee should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Delete employee should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got a response (typically {"deleted": count})
|
||||
assert.NotEmpty(t, result, "Delete employee should return data")
|
||||
}
|
||||
logger.Info("Employee deleted successfully: %s", empID)
|
||||
|
||||
// Verify deletion - just log that delete succeeded
|
||||
@@ -582,7 +614,13 @@ func testRestHeadSpecCRUD(t *testing.T, serverURL string) {
|
||||
|
||||
var result map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&result)
|
||||
assert.True(t, result["success"].(bool), "Delete department should succeed")
|
||||
// Check if response has "success" field (wrapped format) or direct data (unwrapped format)
|
||||
if success, ok := result["success"]; ok && success != nil {
|
||||
assert.True(t, success.(bool), "Delete department should succeed")
|
||||
} else {
|
||||
// Unwrapped format - verify we got a response (typically {"deleted": count})
|
||||
assert.NotEmpty(t, result, "Delete department should return data")
|
||||
}
|
||||
logger.Info("Department deleted successfully: %s", deptID)
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user