Lots of refactoring, Fixes to preloads

This commit is contained in:
Hein
2025-11-21 10:17:20 +02:00
parent 7853a3f56a
commit 7e76977dcc
12 changed files with 563 additions and 386 deletions

View File

@@ -237,7 +237,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
if err != nil {
return
}
}
}()
if len(apply) == 0 {
@@ -401,7 +404,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if b.values != nil && len(b.values) > 0 {
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly

View File

@@ -96,6 +96,117 @@ func IsSQLExpression(cond string) bool {
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
// - An empty string if all conditions were trivial or the input was empty
func SanitizeWhereClause(where string, tableName string) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(cond) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition doesn't already have a table prefix,
// attempt to add it
if tableName != "" && !hasTablePrefix(cond) {
// Check if this is a SQL expression/literal that shouldn't be prefixed
if !IsSQLExpression(strings.ToLower(cond)) {
// Extract the column name and prefix it
columnName := ExtractColumnName(cond)
if columnName != "" {
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
logger.Debug("Prefixed column in condition: '%s'", cond)
}
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is a simple split that doesn't handle nested parentheses or complex expressions
func splitByAND(where string) []string {
// First try uppercase AND
conditions := strings.Split(where, " AND ")
// If we didn't split on uppercase, try lowercase
if len(conditions) == 1 {
conditions = strings.Split(where, " and ")
}
// If we still didn't split, try mixed case
if len(conditions) == 1 {
conditions = strings.Split(where, " And ")
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {

View File

@@ -238,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
var err error
if b == nil {
t = &SqlTimeStamp{}
return nil
}
s := strings.Trim(strings.Trim(string(b), " "), "\"")
if s == "null" || s == "" || s == "0" ||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
t = &SqlTimeStamp{}
return nil
}
@@ -293,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error {
// String - Override String format of time
func (t SqlTimeStamp) String() string {
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
return time.Time(t).Format("2006-01-02T15:04:05")
}
// GetTime - Returns Time
@@ -308,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) {
// Format - Formats the time
func (t SqlTimeStamp) Format(layout string) string {
return fmt.Sprintf("%s", time.Time(t).Format(layout))
return time.Time(t).Format(layout)
}
func SqlTimeStampNow() SqlTimeStamp {
@@ -420,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
if s == "null" || s == "" || s == "0" ||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
s == "0001-01-01" {
t = &SqlDate{}
return nil
}
@@ -434,7 +433,7 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
// MarshalJSON - Override JSON format of time
func (t SqlDate) MarshalJSON() ([]byte, error) {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") {
return []byte("null"), nil
}
@@ -482,7 +481,7 @@ func (t SqlDate) Int64() int64 {
// String - Override String format of time
func (t SqlDate) String() string {
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
return "0"
}
@@ -517,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error {
*t = SqlTime{}
return nil
}
tx := time.Time{}
tx, err = tryParseDT(s)
tx, err := tryParseDT(s)
*t = SqlTime(tx)
return err
@@ -642,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) {
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
if invalid {
s = ""
return nil
}
@@ -661,7 +659,7 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
var obj interface{}
err := json.Unmarshal(n, &obj)
if err != nil {
//fmt.Printf("Invalid JSON %v", err)
// fmt.Printf("Invalid JSON %v", err)
return []byte("null"), nil
}
@@ -725,7 +723,6 @@ func (n *SqlUUID) UnmarshalJSON(b []byte) error {
s := strings.Trim(strings.Trim(string(b), " "), "\"")
invalid := (s == "null" || s == "" || len(s) < 30)
if invalid {
s = ""
return nil
}
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})

View File

@@ -40,7 +40,7 @@ type PreloadOption struct {
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
}

View File

@@ -6,6 +6,7 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ColumnValidator validates column names against a model's fields
@@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
@@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {

View File

@@ -2,6 +2,8 @@ package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestExtractSourceColumn(t *testing.T) {
@@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}