mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-01-01 01:34:25 +00:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59bd709460 | ||
|
|
05962035b6 | ||
|
|
1cd04b7083 | ||
|
|
0d4909054c |
@@ -46,8 +46,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
@@ -56,8 +56,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
||||
|
||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
@@ -86,8 +86,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
||||
|
||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
@@ -235,6 +235,11 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||
}
|
||||
}()
|
||||
if len(apply) == 0 {
|
||||
return sq
|
||||
}
|
||||
@@ -294,8 +299,8 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
||||
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
if dest == nil {
|
||||
@@ -306,8 +311,8 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if b.query.GetModel() == nil {
|
||||
@@ -319,8 +324,8 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
@@ -341,8 +346,8 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
@@ -392,8 +397,8 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
if b.values != nil && len(b.values) > 0 {
|
||||
@@ -478,8 +483,8 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
|
||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
@@ -508,8 +513,8 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
||||
|
||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
|
||||
@@ -41,8 +41,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
@@ -51,8 +51,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
||||
|
||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
@@ -76,8 +76,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
||||
|
||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
@@ -273,8 +273,8 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
||||
|
||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
@@ -282,8 +282,8 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
||||
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if g.db.Statement.Model == nil {
|
||||
@@ -294,8 +294,8 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
@@ -306,8 +306,8 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
|
||||
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
@@ -354,8 +354,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
|
||||
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
@@ -446,8 +446,8 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
|
||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
@@ -478,8 +478,8 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
||||
|
||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil {
|
||||
err = panicErr
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
|
||||
136
pkg/common/sql_helpers.go
Normal file
136
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// Check if the relation name is already present in the WHERE clause
|
||||
lowerWhere := strings.ToLower(where)
|
||||
lowerRelation := strings.ToLower(relationName)
|
||||
|
||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||
// Relation prefix is already present
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||
// we can't safely auto-fix it - require explicit prefix
|
||||
if strings.Contains(lowerWhere, " or ") ||
|
||||
strings.Contains(where, "(") ||
|
||||
strings.Contains(where, ")") {
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||
}
|
||||
|
||||
// Try to add the relation prefix to simple column references
|
||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||
// Split by AND to handle multiple conditions (case-insensitive)
|
||||
originalConditions := strings.Split(where, " AND ")
|
||||
|
||||
// If uppercase split didn't work, try lowercase
|
||||
if len(originalConditions) == 1 {
|
||||
originalConditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
fixedConditions := make([]string, 0, len(originalConditions))
|
||||
|
||||
for _, cond := range originalConditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this condition already has a table prefix (contains a dot)
|
||||
if strings.Contains(cond, ".") {
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||
if IsSQLExpression(lowerCond) {
|
||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the column name (first identifier before operator)
|
||||
columnName := ExtractColumnName(cond)
|
||||
if columnName == "" {
|
||||
// Can't identify column name, require explicit prefix
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||
}
|
||||
|
||||
// Add relation prefix to the column name only
|
||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||
fixedConditions = append(fixedConditions, fixedCond)
|
||||
}
|
||||
|
||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||
return fixedWhere, nil
|
||||
}
|
||||
|
||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||
func IsSQLExpression(cond string) bool {
|
||||
// Common SQL literals and expressions
|
||||
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||
for _, literal := range sqlLiterals {
|
||||
if cond == literal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractColumnName extracts the column name from a WHERE condition
|
||||
// For example: "status = 'active'" returns "status"
|
||||
func ExtractColumnName(cond string) string {
|
||||
// Common SQL operators
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnName := strings.TrimSpace(cond[:idx])
|
||||
// Remove quotes if present
|
||||
columnName = strings.Trim(columnName, "`\"'")
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnName := strings.Trim(parts[0], "`\"'")
|
||||
// Check if it's a valid identifier (not a SQL keyword)
|
||||
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||
func IsSQLKeyword(word string) bool {
|
||||
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||
for _, kw := range keywords {
|
||||
if word == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -104,13 +104,17 @@ func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
}
|
||||
|
||||
// RecoverPanic recovers from panics and returns an error
|
||||
// Use this in deferred functions to convert panics into errors
|
||||
func RecoverPanic(methodName string) error {
|
||||
if r := recover(); r != nil {
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||
}
|
||||
return nil
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
// This should be called with the result of recover() from a deferred function
|
||||
// Example usage:
|
||||
//
|
||||
// defer func() {
|
||||
// if r := recover(); r != nil {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||
}
|
||||
|
||||
@@ -323,6 +323,127 @@ func ExtractColumnFromBunTag(tag string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||
// This function only returns columns that:
|
||||
// 1. Have bun or gorm tags (not just json tags)
|
||||
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||
// 3. Are not scan-only embedded fields
|
||||
func GetSQLModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
collectSQLColumnsFromType(modelType, &columns, false)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Check if the embedded struct itself is scan-only
|
||||
isScanOnly := scanOnlyEmbedded
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
isScanOnly = true
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip fields in scan-only embedded structs
|
||||
if scanOnlyEmbedded {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get bun and gorm tags
|
||||
bunTag := field.Tag.Get("bun")
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
|
||||
// Skip if neither bun nor gorm tag exists
|
||||
if bunTag == "" && gormTag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if explicitly marked with "-"
|
||||
if bunTag == "-" || gormTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is scan-only (bun)
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is read-only (gorm)
|
||||
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (bun)
|
||||
if bunTag != "" {
|
||||
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||
if strings.Contains(bunTag, "rel:") ||
|
||||
strings.Contains(bunTag, "join:") ||
|
||||
strings.Contains(bunTag, "m2m:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip relation fields (gorm)
|
||||
if gormTag != "" {
|
||||
// Skip if it has gorm relationship tags
|
||||
if strings.Contains(gormTag, "foreignKey:") ||
|
||||
strings.Contains(gormTag, "references:") ||
|
||||
strings.Contains(gormTag, "many2many:") ||
|
||||
strings.Contains(gormTag, "constraint:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name
|
||||
columnName := ""
|
||||
if bunTag != "" {
|
||||
columnName = ExtractColumnFromBunTag(bunTag)
|
||||
}
|
||||
if columnName == "" && gormTag != "" {
|
||||
columnName = ExtractColumnFromGormTag(gormTag)
|
||||
}
|
||||
|
||||
// Skip if we couldn't extract a column name
|
||||
if columnName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
// IsColumnWritable checks if a column can be written to in the database
|
||||
// For bun: returns false if the field has "scanonly" tag
|
||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||
|
||||
@@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with relations for GetSQLModelColumns
|
||||
type User struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
Email string `bun:"email" json:"email"`
|
||||
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Title string `gorm:"column:title" json:"title"`
|
||||
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||
Content string `json:"content"` // No bun/gorm tag
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Bio string `bun:"bio" json:"bio"`
|
||||
UserID int `bun:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
}
|
||||
|
||||
// Model with scan-only embedded struct
|
||||
type EntityWithScanOnlyEmbedded struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||
model: User{},
|
||||
// Should include: id, name, email (has bun tags)
|
||||
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||
model: Post{},
|
||||
// Should include: id, title, user_id (has gorm tags)
|
||||
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||
expected: []string{"id", "title", "user_id"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded base and scan-only embedded",
|
||||
model: EntityWithScanOnlyEmbedded{},
|
||||
// Should include: id, name from main struct
|
||||
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||
expected: []string{"id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: ModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: GormModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "Simple Profile model",
|
||||
model: Profile{},
|
||||
// Should include all fields with bun tags
|
||||
expected: []string{"id", "bio", "user_id"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetSQLModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||
len(result), len(tt.expected), result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||
i, col, tt.expected[i], result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||
user := User{}
|
||||
|
||||
allColumns := GetModelColumns(user)
|
||||
sqlColumns := GetSQLModelColumns(user)
|
||||
|
||||
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||
|
||||
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||
if len(allColumns) <= len(sqlColumns) {
|
||||
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||
}
|
||||
|
||||
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||
for _, col := range sqlColumns {
|
||||
if col == "profile_data" {
|
||||
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||
}
|
||||
}
|
||||
|
||||
// GetModelColumns should include 'profile_data' (has json tag)
|
||||
hasProfileData := false
|
||||
for _, col := range allColumns {
|
||||
if col == "profile_data" {
|
||||
hasProfileData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasProfileData {
|
||||
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
|
||||
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
|
||||
// Apply column selection
|
||||
if len(options.Columns) > 0 {
|
||||
logger.Debug("Selecting columns: %v", options.Columns)
|
||||
@@ -1132,15 +1137,20 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
// ORMs like GORM and Bun expect the struct field name, not the JSON name
|
||||
relationFieldName := relInfo.fieldName
|
||||
|
||||
// For now, we'll preload without conditions
|
||||
// TODO: Implement column selection and filtering for preloads
|
||||
// This requires a more sophisticated approach with callbacks or query builders
|
||||
// Apply preloading
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
|
||||
if err != nil {
|
||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
|
||||
panic(fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err))
|
||||
}
|
||||
preload.Where = fixedWhere
|
||||
}
|
||||
|
||||
logger.Debug("Applying preload: %s", relationFieldName)
|
||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := reflection.GetModelColumns(model)
|
||||
allCols := reflection.GetSQLModelColumns(model)
|
||||
// Remove omitted columns
|
||||
preload.Columns = []string{}
|
||||
for _, col := range allCols {
|
||||
|
||||
@@ -260,9 +260,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// Note: X-Files configuration is now applied via parseXFiles which populates
|
||||
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.)
|
||||
// These are applied below in the normal query building process
|
||||
// If we have computed columns/expressions but options.Columns is empty,
|
||||
// populate it with all model columns first since computed columns are additions
|
||||
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
||||
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
|
||||
// Apply ComputedQL fields if any
|
||||
if len(options.ComputedQL) > 0 {
|
||||
@@ -344,6 +347,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
logger.Debug("Applying preload: %s", preload.Relation)
|
||||
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
|
||||
if err != nil {
|
||||
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
|
||||
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
|
||||
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
|
||||
return
|
||||
}
|
||||
preload.Where = fixedWhere
|
||||
}
|
||||
|
||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||
if len(preload.OmitColumns) > 0 {
|
||||
allCols := reflection.GetModelColumns(model)
|
||||
|
||||
Reference in New Issue
Block a user