Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
6 changed files with 222 additions and 56 deletions

View File

@@ -46,8 +46,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
@@ -56,8 +56,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
@@ -86,8 +86,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
@@ -235,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
}
}()
if len(apply) == 0 {
return sq
}
@@ -294,8 +299,8 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
}()
if dest == nil {
@@ -306,8 +311,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
}()
if b.query.GetModel() == nil {
@@ -319,8 +324,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
}()
@@ -341,8 +346,8 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
}()
@@ -392,8 +397,8 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
@@ -478,8 +483,8 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
@@ -508,8 +513,8 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)

View File

@@ -41,8 +41,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
@@ -51,8 +51,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
@@ -76,8 +76,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
@@ -273,8 +273,8 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
return g.db.WithContext(ctx).Find(dest).Error
@@ -282,8 +282,8 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
}
}()
if g.db.Statement.Model == nil {
@@ -294,8 +294,8 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0
}
}()
@@ -306,8 +306,8 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false
}
}()
@@ -354,8 +354,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
@@ -446,8 +446,8 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
@@ -478,8 +478,8 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil {
err = panicErr
if r := recover(); r != nil {
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)

136
pkg/common/sql_helpers.go Normal file
View File

@@ -0,0 +1,136 @@
package common
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
// the relation prefix (alias). If not present, it attempts to add it to column references.
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
// Check if the relation name is already present in the WHERE clause
lowerWhere := strings.ToLower(where)
lowerRelation := strings.ToLower(relationName)
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
if strings.Contains(lowerWhere, lowerRelation+".") ||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
// Relation prefix is already present
return where, nil
}
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
// we can't safely auto-fix it - require explicit prefix
if strings.Contains(lowerWhere, " or ") ||
strings.Contains(where, "(") ||
strings.Contains(where, ")") {
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
}
// Try to add the relation prefix to simple column references
// This handles basic cases like "column = value" or "column = value AND other_column = value"
// Split by AND to handle multiple conditions (case-insensitive)
originalConditions := strings.Split(where, " AND ")
// If uppercase split didn't work, try lowercase
if len(originalConditions) == 1 {
originalConditions = strings.Split(where, " and ")
}
fixedConditions := make([]string, 0, len(originalConditions))
for _, cond := range originalConditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Check if this condition already has a table prefix (contains a dot)
if strings.Contains(cond, ".") {
fixedConditions = append(fixedConditions, cond)
continue
}
// Check if this is a SQL expression/literal that shouldn't be prefixed
lowerCond := strings.ToLower(strings.TrimSpace(cond))
if IsSQLExpression(lowerCond) {
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
fixedConditions = append(fixedConditions, cond)
continue
}
// Extract the column name (first identifier before operator)
columnName := ExtractColumnName(cond)
if columnName == "" {
// Can't identify column name, require explicit prefix
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
}
// Add relation prefix to the column name only
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
fixedConditions = append(fixedConditions, fixedCond)
}
fixedWhere := strings.Join(fixedConditions, " AND ")
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
return fixedWhere, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}

View File

@@ -104,13 +104,17 @@ func CatchPanic(location string) {
CatchPanicCallback(location, nil)
}
// RecoverPanic recovers from panics and returns an error
// Use this in deferred functions to convert panics into errors
func RecoverPanic(methodName string) error {
if r := recover(); r != nil {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
return fmt.Errorf("panic in %s: %v", methodName, r)
}
return nil
// HandlePanic logs a panic and returns it as an error
// This should be called with the result of recover() from a deferred function
// Example usage:
//
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
return fmt.Errorf("panic in %s: %v", methodName, r)
}

View File

@@ -1132,10 +1132,15 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName
// For now, we'll preload without conditions
// TODO: Implement column selection and filtering for preloads
// This requires a more sophisticated approach with callbacks or query builders
// Apply preloading
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
}
preload.Where = fixedWhere
}
logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {

View File

@@ -260,9 +260,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName)
}
// Note: X-Files configuration is now applied via parseXFiles which populates
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.)
// These are applied below in the normal query building process
// If we have computed columns/expressions but options.Columns is empty,
// populate it with all model columns first since computed columns are additions
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetModelColumns(model)
}
// Apply ComputedQL fields if any
if len(options.ComputedQL) > 0 {
@@ -344,6 +347,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
for idx := range options.Preload {
preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation)
// Validate and fix WHERE clause to ensure it contains the relation prefix
if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
return
}
preload.Where = fixedWhere
}
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)