mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 08:14:25 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1261960e97 | ||
|
|
76bbf33db2 | ||
|
|
02c9b96b0c | ||
|
|
9a3564f05f | ||
|
|
a931b8cdd2 | ||
|
|
7e76977dcc |
@@ -86,7 +86,6 @@
|
|||||||
"emptyFallthrough",
|
"emptyFallthrough",
|
||||||
"equalFold",
|
"equalFold",
|
||||||
"flagName",
|
"flagName",
|
||||||
"ifElseChain",
|
|
||||||
"indexAlloc",
|
"indexAlloc",
|
||||||
"initClause",
|
"initClause",
|
||||||
"methodExprCall",
|
"methodExprCall",
|
||||||
@@ -106,6 +105,9 @@
|
|||||||
"unnecessaryBlock",
|
"unnecessaryBlock",
|
||||||
"weakCond",
|
"weakCond",
|
||||||
"yodaStyleExpr"
|
"yodaStyleExpr"
|
||||||
|
],
|
||||||
|
"disabled-checks": [
|
||||||
|
"ifElseChain"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"revive": {
|
"revive": {
|
||||||
|
|||||||
@@ -237,7 +237,10 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if len(apply) == 0 {
|
if len(apply) == 0 {
|
||||||
@@ -401,7 +404,7 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
|||||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if b.values != nil && len(b.values) > 0 {
|
if len(b.values) > 0 {
|
||||||
if !b.hasModel {
|
if !b.hasModel {
|
||||||
// If no model was set, use the values map as the model
|
// If no model was set, use the values map as the model
|
||||||
// Bun can insert map[string]interface{} directly
|
// Bun can insert map[string]interface{} directly
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
@@ -96,6 +98,173 @@ func IsSQLExpression(cond string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
||||||
|
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
||||||
|
func IsTrivialCondition(cond string) bool {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
lowerCond := strings.ToLower(cond)
|
||||||
|
|
||||||
|
// Conditions that always evaluate to true
|
||||||
|
trivialConditions := []string{
|
||||||
|
"1=1", "1 = 1", "1= 1", "1 =1",
|
||||||
|
"true", "true = true", "true=true", "true= true", "true =true",
|
||||||
|
"0=0", "0 = 0", "0= 0", "0 =0",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, trivial := range trivialConditions {
|
||||||
|
if lowerCond == trivial {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||||
|
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - where: The WHERE clause string to sanitize
|
||||||
|
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||||
|
// - An empty string if all conditions were trivial or the input was empty
|
||||||
|
func SanitizeWhereClause(where string, tableName string) string {
|
||||||
|
if where == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Strip outer parentheses and re-trim
|
||||||
|
where = stripOuterParentheses(where)
|
||||||
|
|
||||||
|
// Get valid columns from the model if tableName is provided
|
||||||
|
var validColumns map[string]bool
|
||||||
|
if tableName != "" {
|
||||||
|
validColumns = getValidColumnsForTable(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split by AND to handle multiple conditions
|
||||||
|
conditions := splitByAND(where)
|
||||||
|
|
||||||
|
validConditions := make([]string, 0, len(conditions))
|
||||||
|
|
||||||
|
for _, cond := range conditions {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
if cond == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip parentheses from the condition before checking
|
||||||
|
condToCheck := stripOuterParentheses(cond)
|
||||||
|
|
||||||
|
// Skip trivial conditions that always evaluate to true
|
||||||
|
if IsTrivialCondition(condToCheck) {
|
||||||
|
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||||
|
// attempt to add it
|
||||||
|
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||||
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
|
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||||
|
// Extract the column name and prefix it
|
||||||
|
columnName := ExtractColumnName(condToCheck)
|
||||||
|
if columnName != "" {
|
||||||
|
// Only prefix if this is a valid column in the model
|
||||||
|
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||||
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
|
// Replace in the original condition (without stripped parens)
|
||||||
|
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||||
|
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
validConditions = append(validConditions, cond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validConditions) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
result := strings.Join(validConditions, " AND ")
|
||||||
|
|
||||||
|
if result != where {
|
||||||
|
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripOuterParentheses removes matching outer parentheses from a string
|
||||||
|
// It handles nested parentheses correctly
|
||||||
|
func stripOuterParentheses(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||||
|
depth := 0
|
||||||
|
matched := false
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
|
case '(':
|
||||||
|
depth++
|
||||||
|
case ')':
|
||||||
|
depth--
|
||||||
|
if depth == 0 && i == len(s)-1 {
|
||||||
|
matched = true
|
||||||
|
} else if depth == 0 {
|
||||||
|
// Found a closing paren before the end, so outer parens don't match
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the outer parentheses and continue
|
||||||
|
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
|
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||||
|
func splitByAND(where string) []string {
|
||||||
|
// First try uppercase AND
|
||||||
|
conditions := strings.Split(where, " AND ")
|
||||||
|
|
||||||
|
// If we didn't split on uppercase, try lowercase
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
conditions = strings.Split(where, " and ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still didn't split, try mixed case
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
conditions = strings.Split(where, " And ")
|
||||||
|
}
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
|
||||||
|
func hasTablePrefix(cond string) bool {
|
||||||
|
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
|
||||||
|
return strings.Contains(cond, ".")
|
||||||
|
}
|
||||||
|
|
||||||
// ExtractColumnName extracts the column name from a WHERE condition
|
// ExtractColumnName extracts the column name from a WHERE condition
|
||||||
// For example: "status = 'active'" returns "status"
|
// For example: "status = 'active'" returns "status"
|
||||||
func ExtractColumnName(cond string) string {
|
func ExtractColumnName(cond string) string {
|
||||||
@@ -134,3 +303,38 @@ func IsSQLKeyword(word string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
||||||
|
// Returns a map of column names for fast lookup, or nil if the model is not found
|
||||||
|
func getValidColumnsForTable(tableName string) map[string]bool {
|
||||||
|
// Try to get the model from the registry
|
||||||
|
model, err := modelregistry.GetModelByName(tableName)
|
||||||
|
if err != nil {
|
||||||
|
// Model not found, return nil to indicate we should use fallback behavior
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get SQL columns from the model
|
||||||
|
columns := reflection.GetSQLModelColumns(model)
|
||||||
|
if len(columns) == 0 {
|
||||||
|
// No columns found, return nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map for fast lookup
|
||||||
|
columnMap := make(map[string]bool, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
columnMap[strings.ToLower(col)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidColumn checks if a column name exists in the valid columns map
|
||||||
|
// Handles case-insensitive comparison
|
||||||
|
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||||
|
if validColumns == nil {
|
||||||
|
return true // No model info, assume valid
|
||||||
|
}
|
||||||
|
return validColumns[strings.ToLower(columnName)]
|
||||||
|
}
|
||||||
|
|||||||
224
pkg/common/sql_helpers_test.go
Normal file
224
pkg/common/sql_helpers_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeWhereClause(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trivial conditions in parentheses",
|
||||||
|
where: "(true AND true AND true)",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trivial conditions without parentheses",
|
||||||
|
where: "true AND true AND true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single trivial condition",
|
||||||
|
where: "true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid condition with parentheses",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed trivial and valid conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "condition already with table prefix",
|
||||||
|
where: "users.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid conditions",
|
||||||
|
where: "status = 'active' AND age > 18",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no table name provided",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty where clause",
|
||||||
|
where: "",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripOuterParentheses(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single level parentheses",
|
||||||
|
input: "(true)",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple levels",
|
||||||
|
input: "((true))",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no parentheses",
|
||||||
|
input: "true",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mismatched parentheses",
|
||||||
|
input: "(true",
|
||||||
|
expected: "(true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex expression",
|
||||||
|
input: "(a AND b)",
|
||||||
|
expected: "a AND b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested but not outer",
|
||||||
|
input: "(a AND (b OR c)) AND d",
|
||||||
|
expected: "(a AND (b OR c)) AND d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with spaces",
|
||||||
|
input: " ( true ) ",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := stripOuterParentheses(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTrivialCondition(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"true", "true", true},
|
||||||
|
{"true with spaces", " true ", true},
|
||||||
|
{"TRUE uppercase", "TRUE", true},
|
||||||
|
{"1=1", "1=1", true},
|
||||||
|
{"1 = 1", "1 = 1", true},
|
||||||
|
{"true = true", "true = true", true},
|
||||||
|
{"valid condition", "status = 'active'", false},
|
||||||
|
{"false", "false", false},
|
||||||
|
{"column name", "is_active", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsTrivialCondition(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model for model-aware sanitization tests
|
||||||
|
type MasterTask struct {
|
||||||
|
ID int `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Status string `bun:"status"`
|
||||||
|
UserID int `bun:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||||
|
// Register the test model
|
||||||
|
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||||
|
if err != nil {
|
||||||
|
// Model might already be registered, ignore error
|
||||||
|
t.Logf("Model registration returned: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid column gets prefixed",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid columns get prefixed",
|
||||||
|
where: "status = 'active' AND user_id = 123",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column does not get prefixed",
|
||||||
|
where: "invalid_column = 'value'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "invalid_column = 'value'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mix of valid and trivial conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses with valid column",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -238,13 +238,13 @@ func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
if b == nil {
|
if b == nil {
|
||||||
t = &SqlTimeStamp{}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
if s == "null" || s == "" || s == "0" ||
|
if s == "null" || s == "" || s == "0" ||
|
||||||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
||||||
t = &SqlTimeStamp{}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,7 +293,7 @@ func (t *SqlTimeStamp) Scan(value interface{}) error {
|
|||||||
|
|
||||||
// String - Override String format of time
|
// String - Override String format of time
|
||||||
func (t SqlTimeStamp) String() string {
|
func (t SqlTimeStamp) String() string {
|
||||||
return fmt.Sprintf("%s", time.Time(t).Format("2006-01-02T15:04:05"))
|
return time.Time(t).Format("2006-01-02T15:04:05")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTime - Returns Time
|
// GetTime - Returns Time
|
||||||
@@ -308,7 +308,7 @@ func (t *SqlTimeStamp) SetTime(pTime time.Time) {
|
|||||||
|
|
||||||
// Format - Formats the time
|
// Format - Formats the time
|
||||||
func (t SqlTimeStamp) Format(layout string) string {
|
func (t SqlTimeStamp) Format(layout string) string {
|
||||||
return fmt.Sprintf("%s", time.Time(t).Format(layout))
|
return time.Time(t).Format(layout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func SqlTimeStampNow() SqlTimeStamp {
|
func SqlTimeStampNow() SqlTimeStamp {
|
||||||
@@ -420,7 +420,6 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
|||||||
if s == "null" || s == "" || s == "0" ||
|
if s == "null" || s == "" || s == "0" ||
|
||||||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
||||||
s == "0001-01-01" {
|
s == "0001-01-01" {
|
||||||
t = &SqlDate{}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -434,7 +433,7 @@ func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
|||||||
|
|
||||||
// MarshalJSON - Override JSON format of time
|
// MarshalJSON - Override JSON format of time
|
||||||
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
||||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||||
if strings.HasPrefix(tmstr, "0001-01-01") {
|
if strings.HasPrefix(tmstr, "0001-01-01") {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
}
|
}
|
||||||
@@ -482,7 +481,7 @@ func (t SqlDate) Int64() int64 {
|
|||||||
|
|
||||||
// String - Override String format of time
|
// String - Override String format of time
|
||||||
func (t SqlDate) String() string {
|
func (t SqlDate) String() string {
|
||||||
tmstr := time.Time(t).Format("2006-01-02") //time.RFC3339
|
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||||
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
||||||
return "0"
|
return "0"
|
||||||
}
|
}
|
||||||
@@ -517,8 +516,8 @@ func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
|||||||
*t = SqlTime{}
|
*t = SqlTime{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
tx := time.Time{}
|
|
||||||
tx, err = tryParseDT(s)
|
tx, err := tryParseDT(s)
|
||||||
*t = SqlTime(tx)
|
*t = SqlTime(tx)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -642,9 +641,8 @@ func (n SqlJSONB) AsSlice() ([]any, error) {
|
|||||||
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||||
|
|
||||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
invalid := (s == "null" || s == "" || len(s) < 2) || !(strings.Contains(s, "{") || strings.Contains(s, "["))
|
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
|
||||||
if invalid {
|
if invalid {
|
||||||
s = ""
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -661,7 +659,7 @@ func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
|||||||
var obj interface{}
|
var obj interface{}
|
||||||
err := json.Unmarshal(n, &obj)
|
err := json.Unmarshal(n, &obj)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//fmt.Printf("Invalid JSON %v", err)
|
// fmt.Printf("Invalid JSON %v", err)
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -725,7 +723,6 @@ func (n *SqlUUID) UnmarshalJSON(b []byte) error {
|
|||||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||||
invalid := (s == "null" || s == "" || len(s) < 30)
|
invalid := (s == "null" || s == "" || len(s) < 30)
|
||||||
if invalid {
|
if invalid {
|
||||||
s = ""
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
||||||
|
|||||||
@@ -43,6 +43,11 @@ type PreloadOption struct {
|
|||||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||||
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||||
|
|
||||||
|
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||||
|
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||||
|
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||||
|
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ColumnValidator validates column names against a model's fields
|
// ColumnValidator validates column names against a model's fields
|
||||||
@@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
|||||||
return strings.ToLower(field.Name)
|
return strings.ToLower(field.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
|
||||||
// Examples:
|
|
||||||
// - "columna->>'val'" returns "columna"
|
|
||||||
// - "columna->'key'" returns "columna"
|
|
||||||
// - "columna" returns "columna"
|
|
||||||
// - "table.columna->>'val'" returns "table.columna"
|
|
||||||
func extractSourceColumn(colName string) string {
|
|
||||||
// Check for PostgreSQL JSON operators: -> and ->>
|
|
||||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
|
||||||
return strings.TrimSpace(colName[:idx])
|
|
||||||
}
|
|
||||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
|
||||||
return strings.TrimSpace(colName[:idx])
|
|
||||||
}
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateColumn validates a single column name
|
// ValidateColumn validates a single column name
|
||||||
// Returns nil if valid, error if invalid
|
// Returns nil if valid, error if invalid
|
||||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||||
@@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract source column name (remove JSON operators like ->> or ->)
|
// Extract source column name (remove JSON operators like ->> or ->)
|
||||||
sourceColumn := extractSourceColumn(column)
|
sourceColumn := reflection.ExtractSourceColumn(column)
|
||||||
|
|
||||||
// Check if column exists in model
|
// Check if column exists in model
|
||||||
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractSourceColumn(t *testing.T) {
|
func TestExtractSourceColumn(t *testing.T) {
|
||||||
@@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
result := extractSourceColumn(tc.input)
|
result := reflection.ExtractSourceColumn(tc.input)
|
||||||
if result != tc.expected {
|
if result != tc.expected {
|
||||||
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var lst []ModelFieldDetail
|
lst := make([]ModelFieldDetail, 0)
|
||||||
lst = make([]ModelFieldDetail, 0)
|
|
||||||
|
|
||||||
if !record.IsValid() {
|
if !record.IsValid() {
|
||||||
return lst
|
return lst
|
||||||
|
|||||||
@@ -17,3 +17,33 @@ func Len(v any) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
|
||||||
|
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
|
||||||
|
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
|
||||||
|
// dots, it returns everything after the last dot up to the first delimiter.
|
||||||
|
func ExtractTableNameOnly(fullName string) string {
|
||||||
|
// First, split by dot to remove schema prefix if present
|
||||||
|
lastDotIndex := -1
|
||||||
|
for i, char := range fullName {
|
||||||
|
if char == '.' {
|
||||||
|
lastDotIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start from after the last dot (or from beginning if no dot)
|
||||||
|
startIndex := 0
|
||||||
|
if lastDotIndex != -1 {
|
||||||
|
startIndex = lastDotIndex + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now find the end (first delimiter after the table name)
|
||||||
|
for i := startIndex; i < len(fullName); i++ {
|
||||||
|
char := rune(fullName[i])
|
||||||
|
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
|
||||||
|
return fullName[startIndex:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullName[startIndex:]
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package reflection
|
package reflection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
@@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if field name matches
|
// Check if field name matches
|
||||||
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
|
||||||
return fieldValue.Interface()
|
return fieldValue.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -472,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
|
|
||||||
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||||
// Returns (found, writable) where found indicates if the column was found
|
// Returns (found, writable) where found indicates if the column was found
|
||||||
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
|
||||||
@@ -561,3 +563,321 @@ func isGormFieldReadOnly(tag string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||||
|
// Examples:
|
||||||
|
// - "columna->>'val'" returns "columna"
|
||||||
|
// - "columna->'key'" returns "columna"
|
||||||
|
// - "columna" returns "columna"
|
||||||
|
// - "table.columna->>'val'" returns "table.columna"
|
||||||
|
func ExtractSourceColumn(colName string) string {
|
||||||
|
// Check for PostgreSQL JSON operators: -> and ->>
|
||||||
|
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||||
|
return strings.TrimSpace(colName[:idx])
|
||||||
|
}
|
||||||
|
return colName
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||||
|
func ToSnakeCase(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||||
|
result.WriteRune('_')
|
||||||
|
}
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
return strings.ToLower(result.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||||
|
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||||
|
if model == nil {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||||
|
sourceColName := ExtractSourceColumn(colName)
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
// Dereference pointer if needed
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field by JSON tag or field name
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
// Check JSON tag
|
||||||
|
jsonTag := field.Tag.Get("json")
|
||||||
|
if jsonTag != "" {
|
||||||
|
// Parse JSON tag (format: "name,omitempty")
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if parts[0] == sourceColName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check field name (case-insensitive)
|
||||||
|
if strings.EqualFold(field.Name, sourceColName) {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check snake_case conversion
|
||||||
|
snakeCaseName := ToSnakeCase(field.Name)
|
||||||
|
if snakeCaseName == sourceColName {
|
||||||
|
return field.Type.Kind()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNumericType checks if a reflect.Kind is a numeric type
|
||||||
|
func IsNumericType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||||
|
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||||
|
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||||
|
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsStringType checks if a reflect.Kind is a string type
|
||||||
|
func IsStringType(kind reflect.Kind) bool {
|
||||||
|
return kind == reflect.String
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNumericValue checks if a string value can be parsed as a number
|
||||||
|
func IsNumericValue(value string) bool {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
_, err := strconv.ParseFloat(value, 64)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertToNumericType converts a string value to the appropriate numeric type
|
||||||
|
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
// Parse as integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Int16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Int32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Int:
|
||||||
|
return int(intVal), nil
|
||||||
|
case reflect.Int8:
|
||||||
|
return int8(intVal), nil
|
||||||
|
case reflect.Int16:
|
||||||
|
return int16(intVal), nil
|
||||||
|
case reflect.Int32:
|
||||||
|
return int32(intVal), nil
|
||||||
|
case reflect.Int64:
|
||||||
|
return intVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
// Parse as unsigned integer
|
||||||
|
bitSize := 64
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint8:
|
||||||
|
bitSize = 8
|
||||||
|
case reflect.Uint16:
|
||||||
|
bitSize = 16
|
||||||
|
case reflect.Uint32:
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the appropriate type
|
||||||
|
switch kind {
|
||||||
|
case reflect.Uint:
|
||||||
|
return uint(uintVal), nil
|
||||||
|
case reflect.Uint8:
|
||||||
|
return uint8(uintVal), nil
|
||||||
|
case reflect.Uint16:
|
||||||
|
return uint16(uintVal), nil
|
||||||
|
case reflect.Uint32:
|
||||||
|
return uint32(uintVal), nil
|
||||||
|
case reflect.Uint64:
|
||||||
|
return uintVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
// Parse as float
|
||||||
|
bitSize := 64
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
bitSize = 32
|
||||||
|
}
|
||||||
|
|
||||||
|
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == reflect.Float32 {
|
||||||
|
return float32(floatVal), nil
|
||||||
|
}
|
||||||
|
return floatVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationModel gets the model type for a relation field
|
||||||
|
// It searches for the field by name in the following order (case-insensitive):
|
||||||
|
// 1. Actual field name
|
||||||
|
// 2. Bun tag name (if exists)
|
||||||
|
// 3. Gorm tag name (if exists)
|
||||||
|
// 4. JSON tag name (if exists)
|
||||||
|
//
|
||||||
|
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
|
||||||
|
// For nested fields, it traverses through each level of the struct hierarchy
|
||||||
|
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split the field name by "." to handle nested/recursive relations
|
||||||
|
fieldParts := strings.Split(fieldName, ".")
|
||||||
|
|
||||||
|
// Start with the current model
|
||||||
|
currentModel := model
|
||||||
|
|
||||||
|
// Traverse through each level of the field path
|
||||||
|
for _, part := range fieldParts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
currentModel = getRelationModelSingleLevel(currentModel, part)
|
||||||
|
if currentModel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||||
|
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||||
|
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the field by checking in priority order (case-insensitive)
|
||||||
|
var field *reflect.StructField
|
||||||
|
normalizedFieldName := strings.ToLower(fieldName)
|
||||||
|
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
f := modelType.Field(i)
|
||||||
|
|
||||||
|
// 1. Check actual field name (case-insensitive)
|
||||||
|
if strings.EqualFold(f.Name, fieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check bun tag name
|
||||||
|
bunTag := f.Tag.Get("bun")
|
||||||
|
if bunTag != "" {
|
||||||
|
bunColName := ExtractColumnFromBunTag(bunTag)
|
||||||
|
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Check gorm tag name
|
||||||
|
gormTag := f.Tag.Get("gorm")
|
||||||
|
if gormTag != "" {
|
||||||
|
gormColName := ExtractColumnFromGormTag(gormTag)
|
||||||
|
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Check JSON tag name
|
||||||
|
jsonTag := f.Tag.Get("json")
|
||||||
|
if jsonTag != "" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||||
|
if strings.EqualFold(parts[0], normalizedFieldName) {
|
||||||
|
field = &f
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if field == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the target type
|
||||||
|
targetType := field.Type
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() == reflect.Slice {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
if targetType == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if targetType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a zero value of the target type
|
||||||
|
return reflect.New(targetType).Elem().Interface()
|
||||||
|
}
|
||||||
|
|||||||
@@ -199,7 +199,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
query = query.Column(options.Columns...)
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.ComputedColumns) > 0 {
|
if len(options.ComputedColumns) > 0 {
|
||||||
@@ -1149,6 +1151,11 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
|
|
||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
|
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||||
|
preload.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle column selection and omission
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetSQLModelColumns(model)
|
allCols := reflection.GetSQLModelColumns(model)
|
||||||
// Remove omitted columns
|
// Remove omitted columns
|
||||||
@@ -1204,7 +1211,10 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sq = sq.Where(preload.Where)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||||
|
if len(sanitizedWhere) > 0 {
|
||||||
|
sq = sq.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if preload.Limit != nil && *preload.Limit > 0 {
|
if preload.Limit != nil && *preload.Limit > 0 {
|
||||||
|
|||||||
@@ -213,6 +213,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
model := GetModel(ctx)
|
model := GetModel(ctx)
|
||||||
|
|
||||||
|
if id == "" {
|
||||||
|
options.SingleRecordAsObject = false
|
||||||
|
}
|
||||||
|
|
||||||
// Execute BeforeRead hooks
|
// Execute BeforeRead hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@@ -299,7 +303,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
query = query.Column(options.Columns...)
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply expand (Just expand to Preload for now)
|
// Apply expand (Just expand to Preload for now)
|
||||||
@@ -391,13 +398,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply custom SQL WHERE clause (AND condition)
|
// Apply custom SQL WHERE clause (AND condition)
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
query = query.Where(options.CustomSQLWhere)
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
|
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||||
|
if sanitizedWhere != "" {
|
||||||
|
query = query.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply custom SQL WHERE clause (OR condition)
|
// Apply custom SQL WHERE clause (OR condition)
|
||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
query = query.WhereOr(options.CustomSQLOr)
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
|
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||||
|
if sanitizedOr != "" {
|
||||||
|
query = query.WhereOr(sanitizedOr)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If ID is provided, filter by ID
|
// If ID is provided, filter by ID
|
||||||
@@ -473,7 +488,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply cursor filter to query
|
// Apply cursor filter to query
|
||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
query = query.Where(cursorFilter)
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||||
|
if sanitizedCursor != "" {
|
||||||
|
query = query.Where(sanitizedCursor)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -549,18 +567,41 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||||
|
// Log relationship keys if they're specified (from XFiles)
|
||||||
|
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||||
|
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||||
|
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
|
||||||
|
|
||||||
|
// Build a WHERE clause using the relationship keys if needed
|
||||||
|
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
|
||||||
|
// However, if the relationship keys are explicitly provided from XFiles, we can use them
|
||||||
|
// to add additional filtering or validation
|
||||||
|
if preload.RelatedKey != "" && preload.Where == "" {
|
||||||
|
// For child tables: ensure the child's relatedkey column will be matched
|
||||||
|
// The actual parent value is dynamic and handled by Bun's preload mechanism
|
||||||
|
// We just log this for visibility
|
||||||
|
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
|
||||||
|
preload.Relation, preload.RelatedKey)
|
||||||
|
}
|
||||||
|
if preload.ForeignKey != "" && preload.Where == "" {
|
||||||
|
// For parent tables: ensure the parent's primary key matches the current table's foreign key
|
||||||
|
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
|
||||||
|
preload.Relation, preload.ForeignKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply the preload
|
// Apply the preload
|
||||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
// Get the related model for column operations
|
// Get the related model for column operations
|
||||||
relatedModel := h.getRelationModel(model, preload.Relation)
|
relatedModel := reflection.GetRelationModel(model, preload.Relation)
|
||||||
if relatedModel == nil {
|
if relatedModel == nil {
|
||||||
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
||||||
relatedModel = model // fallback to parent model
|
// relatedModel = model // fallback to parent model
|
||||||
}
|
} else {
|
||||||
|
|
||||||
// If we have computed columns but no explicit columns, populate with all model columns first
|
// If we have computed columns but no explicit columns, populate with all model columns first
|
||||||
// since computed columns are additions
|
// since computed columns are additions
|
||||||
if len(preload.Columns) == 0 && len(preload.ComputedQL) > 0 && relatedModel != nil {
|
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
|
||||||
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
|
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
|
||||||
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
|
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
|
||||||
}
|
}
|
||||||
@@ -581,8 +622,8 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle OmitColumns
|
// Handle OmitColumns
|
||||||
if len(preload.OmitColumns) > 0 && relatedModel != nil {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(relatedModel)
|
allCols := preload.Columns
|
||||||
// Remove omitted columns
|
// Remove omitted columns
|
||||||
preload.Columns = []string{}
|
preload.Columns = []string{}
|
||||||
for _, col := range allCols {
|
for _, col := range allCols {
|
||||||
@@ -603,6 +644,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
if len(preload.Columns) > 0 {
|
if len(preload.Columns) > 0 {
|
||||||
sq = sq.Column(preload.Columns...)
|
sq = sq.Column(preload.Columns...)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters
|
||||||
if len(preload.Filters) > 0 {
|
if len(preload.Filters) > 0 {
|
||||||
@@ -620,7 +662,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply WHERE clause
|
// Apply WHERE clause
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sq = sq.Where(preload.Where)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||||
|
if len(sanitizedWhere) > 0 {
|
||||||
|
sq = sq.Where(sanitizedWhere)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply limit
|
// Apply limit
|
||||||
@@ -628,6 +673,10 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
sq = sq.Limit(*preload.Limit)
|
sq = sq.Limit(*preload.Limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if preload.Offset != nil && *preload.Offset > 0 {
|
||||||
|
sq = sq.Offset(*preload.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
return sq
|
return sq
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1312,7 +1361,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
|||||||
func (h *Handler) extractNestedRelations(
|
func (h *Handler) extractNestedRelations(
|
||||||
data map[string]interface{},
|
data map[string]interface{},
|
||||||
model interface{},
|
model interface{},
|
||||||
) (map[string]interface{}, map[string]interface{}, error) {
|
) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
|
||||||
// Get model type for reflection
|
// Get model type for reflection
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
@@ -1741,7 +1790,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return data
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtendedRequestOptions extends common.RequestOptions with additional features
|
// ExtendedRequestOptions extends common.RequestOptions with additional features
|
||||||
@@ -122,78 +123,85 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
|
|
||||||
// Merge headers and query parameters - query parameters take precedence
|
// Merge headers and query parameters - query parameters take precedence
|
||||||
// This allows the same parameters to be specified in either headers or query string
|
// This allows the same parameters to be specified in either headers or query string
|
||||||
|
// Normalize keys to lowercase to ensure query params properly override headers
|
||||||
combinedParams := make(map[string]string)
|
combinedParams := make(map[string]string)
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
combinedParams[key] = value
|
combinedParams[strings.ToLower(key)] = value
|
||||||
}
|
}
|
||||||
for key, value := range queryParams {
|
for key, value := range queryParams {
|
||||||
combinedParams[key] = value
|
combinedParams[strings.ToLower(key)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process each parameter (from both headers and query params)
|
// Process each parameter (from both headers and query params)
|
||||||
|
// Note: keys are already normalized to lowercase in combinedParams
|
||||||
for key, value := range combinedParams {
|
for key, value := range combinedParams {
|
||||||
// Normalize parameter key to lowercase for consistent matching
|
|
||||||
normalizedKey := strings.ToLower(key)
|
|
||||||
|
|
||||||
// Decode value if it's base64 encoded
|
// Decode value if it's base64 encoded
|
||||||
decodedValue := decodeHeaderValue(value)
|
decodedValue := decodeHeaderValue(value)
|
||||||
|
|
||||||
// Parse based on parameter prefix/name
|
// Parse based on parameter prefix/name
|
||||||
switch {
|
switch {
|
||||||
// Field Selection
|
// Field Selection
|
||||||
case strings.HasPrefix(normalizedKey, "x-select-fields"):
|
case strings.HasPrefix(key, "x-select-fields"):
|
||||||
h.parseSelectFields(&options, decodedValue)
|
h.parseSelectFields(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-not-select-fields"):
|
case strings.HasPrefix(key, "x-not-select-fields"):
|
||||||
h.parseNotSelectFields(&options, decodedValue)
|
h.parseNotSelectFields(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-clean-json"):
|
case strings.HasPrefix(key, "x-clean-json"):
|
||||||
options.CleanJSON = strings.EqualFold(decodedValue, "true")
|
options.CleanJSON = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
// Filtering & Search
|
// Filtering & Search
|
||||||
case strings.HasPrefix(normalizedKey, "x-fieldfilter-"):
|
case strings.HasPrefix(key, "x-fieldfilter-"):
|
||||||
h.parseFieldFilter(&options, normalizedKey, decodedValue)
|
h.parseFieldFilter(&options, key, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchfilter-"):
|
case strings.HasPrefix(key, "x-searchfilter-"):
|
||||||
h.parseSearchFilter(&options, normalizedKey, decodedValue)
|
h.parseSearchFilter(&options, key, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchop-"):
|
case strings.HasPrefix(key, "x-searchop-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchor-"):
|
case strings.HasPrefix(key, "x-searchor-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "OR")
|
h.parseSearchOp(&options, key, decodedValue, "OR")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchand-"):
|
case strings.HasPrefix(key, "x-searchand-"):
|
||||||
h.parseSearchOp(&options, normalizedKey, decodedValue, "AND")
|
h.parseSearchOp(&options, key, decodedValue, "AND")
|
||||||
case strings.HasPrefix(normalizedKey, "x-searchcols"):
|
case strings.HasPrefix(key, "x-searchcols"):
|
||||||
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-w"):
|
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||||
|
if options.CustomSQLWhere != "" {
|
||||||
|
options.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", options.CustomSQLWhere, decodedValue)
|
||||||
|
} else {
|
||||||
options.CustomSQLWhere = decodedValue
|
options.CustomSQLWhere = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-or"):
|
}
|
||||||
|
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||||
|
if options.CustomSQLOr != "" {
|
||||||
|
options.CustomSQLOr = fmt.Sprintf("%s OR (%s)", options.CustomSQLOr, decodedValue)
|
||||||
|
} else {
|
||||||
options.CustomSQLOr = decodedValue
|
options.CustomSQLOr = decodedValue
|
||||||
|
}
|
||||||
|
|
||||||
// Joins & Relations
|
// Joins & Relations
|
||||||
case strings.HasPrefix(normalizedKey, "x-preload"):
|
case strings.HasPrefix(key, "x-preload"):
|
||||||
if strings.HasSuffix(normalizedKey, "-where") {
|
if strings.HasSuffix(key, "-where") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
|
whereClaude := combinedParams[fmt.Sprintf("%s-where", key)]
|
||||||
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
h.parsePreload(&options, decodedValue, decodeHeaderValue(whereClaude))
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-expand"):
|
case strings.HasPrefix(key, "x-expand"):
|
||||||
h.parseExpand(&options, decodedValue)
|
h.parseExpand(&options, decodedValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-custom-sql-join"):
|
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||||
// TODO: Implement custom SQL join
|
// TODO: Implement custom SQL join
|
||||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||||
|
|
||||||
// Sorting & Pagination
|
// Sorting & Pagination
|
||||||
case strings.HasPrefix(normalizedKey, "x-sort"):
|
case strings.HasPrefix(key, "x-sort"):
|
||||||
h.parseSorting(&options, decodedValue)
|
h.parseSorting(&options, decodedValue)
|
||||||
//Special cases for older clients using sort(a,b,-c) syntax
|
// Special cases for older clients using sort(a,b,-c) syntax
|
||||||
case strings.HasPrefix(normalizedKey, "sort(") && strings.Contains(normalizedKey, ")"):
|
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
|
||||||
sortValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
|
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
h.parseSorting(&options, sortValue)
|
h.parseSorting(&options, sortValue)
|
||||||
case strings.HasPrefix(normalizedKey, "x-limit"):
|
case strings.HasPrefix(key, "x-limit"):
|
||||||
if limit, err := strconv.Atoi(decodedValue); err == nil {
|
if limit, err := strconv.Atoi(decodedValue); err == nil {
|
||||||
options.Limit = &limit
|
options.Limit = &limit
|
||||||
}
|
}
|
||||||
//Special cases for older clients using limit(n) syntax
|
// Special cases for older clients using limit(n) syntax
|
||||||
case strings.HasPrefix(normalizedKey, "limit(") && strings.Contains(normalizedKey, ")"):
|
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
|
||||||
limitValue := normalizedKey[strings.Index(normalizedKey, "(")+1 : strings.Index(normalizedKey, ")")]
|
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||||
limitValueParts := strings.Split(limitValue, ",")
|
limitValueParts := strings.Split(limitValue, ",")
|
||||||
|
|
||||||
if len(limitValueParts) > 1 {
|
if len(limitValueParts) > 1 {
|
||||||
@@ -209,42 +217,43 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-offset"):
|
case strings.HasPrefix(key, "x-offset"):
|
||||||
if offset, err := strconv.Atoi(decodedValue); err == nil {
|
if offset, err := strconv.Atoi(decodedValue); err == nil {
|
||||||
options.Offset = &offset
|
options.Offset = &offset
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.HasPrefix(normalizedKey, "x-cursor-forward"):
|
case strings.HasPrefix(key, "x-cursor-forward"):
|
||||||
options.CursorForward = decodedValue
|
options.CursorForward = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-cursor-backward"):
|
case strings.HasPrefix(key, "x-cursor-backward"):
|
||||||
options.CursorBackward = decodedValue
|
options.CursorBackward = decodedValue
|
||||||
|
|
||||||
// Advanced Features
|
// Advanced Features
|
||||||
case strings.HasPrefix(normalizedKey, "x-advsql-"):
|
case strings.HasPrefix(key, "x-advsql-"):
|
||||||
colName := strings.TrimPrefix(normalizedKey, "x-advsql-")
|
colName := strings.TrimPrefix(key, "x-advsql-")
|
||||||
options.AdvancedSQL[colName] = decodedValue
|
options.AdvancedSQL[colName] = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-cql-sel-"):
|
case strings.HasPrefix(key, "x-cql-sel-"):
|
||||||
colName := strings.TrimPrefix(normalizedKey, "x-cql-sel-")
|
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
||||||
options.ComputedQL[colName] = decodedValue
|
options.ComputedQL[colName] = decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-distinct"):
|
|
||||||
|
case strings.HasPrefix(key, "x-distinct"):
|
||||||
options.Distinct = strings.EqualFold(decodedValue, "true")
|
options.Distinct = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-skipcount"):
|
case strings.HasPrefix(key, "x-skipcount"):
|
||||||
options.SkipCount = strings.EqualFold(decodedValue, "true")
|
options.SkipCount = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-skipcache"):
|
case strings.HasPrefix(key, "x-skipcache"):
|
||||||
options.SkipCache = strings.EqualFold(decodedValue, "true")
|
options.SkipCache = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(normalizedKey, "x-fetch-rownumber"):
|
case strings.HasPrefix(key, "x-fetch-rownumber"):
|
||||||
options.FetchRowNumber = &decodedValue
|
options.FetchRowNumber = &decodedValue
|
||||||
case strings.HasPrefix(normalizedKey, "x-pkrow"):
|
case strings.HasPrefix(key, "x-pkrow"):
|
||||||
options.PKRow = &decodedValue
|
options.PKRow = &decodedValue
|
||||||
|
|
||||||
// Response Format
|
// Response Format
|
||||||
case strings.HasPrefix(normalizedKey, "x-simpleapi"):
|
case strings.HasPrefix(key, "x-simpleapi"):
|
||||||
options.ResponseFormat = "simple"
|
options.ResponseFormat = "simple"
|
||||||
case strings.HasPrefix(normalizedKey, "x-detailapi"):
|
case strings.HasPrefix(key, "x-detailapi"):
|
||||||
options.ResponseFormat = "detail"
|
options.ResponseFormat = "detail"
|
||||||
case strings.HasPrefix(normalizedKey, "x-syncfusion"):
|
case strings.HasPrefix(key, "x-syncfusion"):
|
||||||
options.ResponseFormat = "syncfusion"
|
options.ResponseFormat = "syncfusion"
|
||||||
case strings.HasPrefix(normalizedKey, "x-single-record-as-object"):
|
case strings.HasPrefix(key, "x-single-record-as-object"):
|
||||||
// Parse as boolean - "false" disables, "true" enables (default is true)
|
// Parse as boolean - "false" disables, "true" enables (default is true)
|
||||||
if strings.EqualFold(decodedValue, "false") {
|
if strings.EqualFold(decodedValue, "false") {
|
||||||
options.SingleRecordAsObject = false
|
options.SingleRecordAsObject = false
|
||||||
@@ -253,11 +262,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Transaction Control
|
// Transaction Control
|
||||||
case strings.HasPrefix(normalizedKey, "x-transaction-atomic"):
|
case strings.HasPrefix(key, "x-transaction-atomic"):
|
||||||
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
|
options.AtomicTransaction = strings.EqualFold(decodedValue, "true")
|
||||||
|
|
||||||
// X-Files - comprehensive JSON configuration
|
// X-Files - comprehensive JSON configuration
|
||||||
case strings.HasPrefix(normalizedKey, "x-files"):
|
case strings.HasPrefix(key, "x-files"):
|
||||||
h.parseXFiles(&options, decodedValue)
|
h.parseXFiles(&options, decodedValue)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -267,6 +276,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
h.resolveRelationNamesInOptions(&options, model)
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always sort according to the primary key if no sorting is specified
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -720,7 +735,7 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
|||||||
|
|
||||||
// Try to get the model type for the next level
|
// Try to get the model type for the next level
|
||||||
// This allows nested resolution
|
// This allows nested resolution
|
||||||
if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil {
|
if nextModel := reflection.GetRelationModel(currentModel, resolvedPart); nextModel != nil {
|
||||||
currentModel = nextModel
|
currentModel = nextModel
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -744,58 +759,6 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRelationModel gets the model type for a relation field
|
|
||||||
func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} {
|
|
||||||
if model == nil || fieldName == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
|
||||||
if modelType == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the field
|
|
||||||
field, found := modelType.FieldByName(fieldName)
|
|
||||||
if !found {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the target type
|
|
||||||
targetType := field.Type
|
|
||||||
if targetType == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Slice {
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
if targetType == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
if targetType == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Struct {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a zero value of the target type
|
|
||||||
return reflect.New(targetType).Elem().Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveRelationName resolves a relation name or table name to the actual field name in the model
|
// resolveRelationName resolves a relation name or table name to the actual field name in the model
|
||||||
// If the input is already a field name, it returns it as-is
|
// If the input is already a field name, it returns it as-is
|
||||||
// If the input is a table name, it looks up the corresponding relation field
|
// If the input is a table name, it looks up the corresponding relation field
|
||||||
@@ -971,6 +934,20 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
// Set recursive flag
|
// Set recursive flag
|
||||||
preloadOpt.Recursive = xfile.Recursive
|
preloadOpt.Recursive = xfile.Recursive
|
||||||
|
|
||||||
|
// Extract relationship keys for proper foreign key filtering
|
||||||
|
if xfile.PrimaryKey != "" {
|
||||||
|
preloadOpt.PrimaryKey = xfile.PrimaryKey
|
||||||
|
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
|
||||||
|
}
|
||||||
|
if xfile.RelatedKey != "" {
|
||||||
|
preloadOpt.RelatedKey = xfile.RelatedKey
|
||||||
|
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
|
||||||
|
}
|
||||||
|
if xfile.ForeignKey != "" {
|
||||||
|
preloadOpt.ForeignKey = xfile.ForeignKey
|
||||||
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
|
}
|
||||||
|
|
||||||
// Add the preload option
|
// Add the preload option
|
||||||
options.Preload = append(options.Preload, preloadOpt)
|
options.Preload = append(options.Preload, preloadOpt)
|
||||||
|
|
||||||
@@ -983,192 +960,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
|
||||||
// Examples:
|
|
||||||
// - "columna->>'val'" returns "columna"
|
|
||||||
// - "columna->'key'" returns "columna"
|
|
||||||
// - "columna" returns "columna"
|
|
||||||
// - "table.columna->>'val'" returns "table.columna"
|
|
||||||
func extractSourceColumn(colName string) string {
|
|
||||||
// Check for PostgreSQL JSON operators: -> and ->>
|
|
||||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
|
||||||
return strings.TrimSpace(colName[:idx])
|
|
||||||
}
|
|
||||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
|
||||||
return strings.TrimSpace(colName[:idx])
|
|
||||||
}
|
|
||||||
return colName
|
|
||||||
}
|
|
||||||
|
|
||||||
// getColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
|
||||||
func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
|
||||||
if model == nil {
|
|
||||||
return reflect.Invalid
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the source column name (remove JSON operators like ->> or ->)
|
|
||||||
sourceColName := extractSourceColumn(colName)
|
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
|
||||||
// Dereference pointer if needed
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
|
||||||
modelType = modelType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure it's a struct
|
|
||||||
if modelType.Kind() != reflect.Struct {
|
|
||||||
return reflect.Invalid
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the field by JSON tag or field name
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
|
|
||||||
// Check JSON tag
|
|
||||||
jsonTag := field.Tag.Get("json")
|
|
||||||
if jsonTag != "" {
|
|
||||||
// Parse JSON tag (format: "name,omitempty")
|
|
||||||
parts := strings.Split(jsonTag, ",")
|
|
||||||
if parts[0] == sourceColName {
|
|
||||||
return field.Type.Kind()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check field name (case-insensitive)
|
|
||||||
if strings.EqualFold(field.Name, sourceColName) {
|
|
||||||
return field.Type.Kind()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check snake_case conversion
|
|
||||||
snakeCaseName := toSnakeCase(field.Name)
|
|
||||||
if snakeCaseName == sourceColName {
|
|
||||||
return field.Type.Kind()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return reflect.Invalid
|
|
||||||
}
|
|
||||||
|
|
||||||
// toSnakeCase converts a string from CamelCase to snake_case
|
|
||||||
func toSnakeCase(s string) string {
|
|
||||||
var result strings.Builder
|
|
||||||
for i, r := range s {
|
|
||||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
|
||||||
result.WriteRune('_')
|
|
||||||
}
|
|
||||||
result.WriteRune(r)
|
|
||||||
}
|
|
||||||
return strings.ToLower(result.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// isNumericType checks if a reflect.Kind is a numeric type
|
|
||||||
func isNumericType(kind reflect.Kind) bool {
|
|
||||||
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
|
||||||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
|
||||||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
|
||||||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
|
||||||
}
|
|
||||||
|
|
||||||
// isStringType checks if a reflect.Kind is a string type
|
|
||||||
func isStringType(kind reflect.Kind) bool {
|
|
||||||
return kind == reflect.String
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToNumericType converts a string value to the appropriate numeric type
|
|
||||||
func convertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
|
||||||
// Parse as integer
|
|
||||||
bitSize := 64
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int8:
|
|
||||||
bitSize = 8
|
|
||||||
case reflect.Int16:
|
|
||||||
bitSize = 16
|
|
||||||
case reflect.Int32:
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid integer value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the appropriate type
|
|
||||||
switch kind {
|
|
||||||
case reflect.Int:
|
|
||||||
return int(intVal), nil
|
|
||||||
case reflect.Int8:
|
|
||||||
return int8(intVal), nil
|
|
||||||
case reflect.Int16:
|
|
||||||
return int16(intVal), nil
|
|
||||||
case reflect.Int32:
|
|
||||||
return int32(intVal), nil
|
|
||||||
case reflect.Int64:
|
|
||||||
return intVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
// Parse as unsigned integer
|
|
||||||
bitSize := 64
|
|
||||||
switch kind {
|
|
||||||
case reflect.Uint8:
|
|
||||||
bitSize = 8
|
|
||||||
case reflect.Uint16:
|
|
||||||
bitSize = 16
|
|
||||||
case reflect.Uint32:
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the appropriate type
|
|
||||||
switch kind {
|
|
||||||
case reflect.Uint:
|
|
||||||
return uint(uintVal), nil
|
|
||||||
case reflect.Uint8:
|
|
||||||
return uint8(uintVal), nil
|
|
||||||
case reflect.Uint16:
|
|
||||||
return uint16(uintVal), nil
|
|
||||||
case reflect.Uint32:
|
|
||||||
return uint32(uintVal), nil
|
|
||||||
case reflect.Uint64:
|
|
||||||
return uintVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Float32, reflect.Float64:
|
|
||||||
// Parse as float
|
|
||||||
bitSize := 64
|
|
||||||
if kind == reflect.Float32 {
|
|
||||||
bitSize = 32
|
|
||||||
}
|
|
||||||
|
|
||||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid float value: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if kind == reflect.Float32 {
|
|
||||||
return float32(floatVal), nil
|
|
||||||
}
|
|
||||||
return floatVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
|
||||||
}
|
|
||||||
|
|
||||||
// isNumericValue checks if a string value can be parsed as a number
|
|
||||||
func isNumericValue(value string) bool {
|
|
||||||
value = strings.TrimSpace(value)
|
|
||||||
_, err := strconv.ParseFloat(value, 64)
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnCastInfo holds information about whether a column needs casting
|
// ColumnCastInfo holds information about whether a column needs casting
|
||||||
type ColumnCastInfo struct {
|
type ColumnCastInfo struct {
|
||||||
NeedsCast bool
|
NeedsCast bool
|
||||||
@@ -1182,7 +973,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
}
|
}
|
||||||
|
|
||||||
colType := h.getColumnTypeFromModel(model, filter.Column)
|
colType := reflection.GetColumnTypeFromModel(model, filter.Column)
|
||||||
if colType == reflect.Invalid {
|
if colType == reflect.Invalid {
|
||||||
// Column not found in model, no casting needed
|
// Column not found in model, no casting needed
|
||||||
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
logger.Debug("Column %s not found in model, skipping type validation", filter.Column)
|
||||||
@@ -1193,18 +984,18 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
valueIsNumeric := false
|
valueIsNumeric := false
|
||||||
if strVal, ok := filter.Value.(string); ok {
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
strVal = strings.Trim(strVal, "%")
|
strVal = strings.Trim(strVal, "%")
|
||||||
valueIsNumeric = isNumericValue(strVal)
|
valueIsNumeric = reflection.IsNumericValue(strVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Adjust based on column type
|
// Adjust based on column type
|
||||||
switch {
|
switch {
|
||||||
case isNumericType(colType):
|
case reflection.IsNumericType(colType):
|
||||||
// Column is numeric
|
// Column is numeric
|
||||||
if valueIsNumeric {
|
if valueIsNumeric {
|
||||||
// Value is numeric - try to convert it
|
// Value is numeric - try to convert it
|
||||||
if strVal, ok := filter.Value.(string); ok {
|
if strVal, ok := filter.Value.(string); ok {
|
||||||
strVal = strings.Trim(strVal, "%")
|
strVal = strings.Trim(strVal, "%")
|
||||||
numericVal, err := convertToNumericType(strVal, colType)
|
numericVal, err := reflection.ConvertToNumericType(strVal, colType)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
logger.Debug("Failed to convert value '%s' to numeric type for column %s, will use text cast", strVal, filter.Column)
|
||||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
@@ -1219,7 +1010,7 @@ func (h *Handler) ValidateAndAdjustFilterForColumnType(filter *common.FilterOpti
|
|||||||
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
return ColumnCastInfo{NeedsCast: true, IsNumericType: true}
|
||||||
}
|
}
|
||||||
|
|
||||||
case isStringType(colType):
|
case reflection.IsStringType(colType):
|
||||||
// String columns don't need casting
|
// String columns don't need casting
|
||||||
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
return ColumnCastInfo{NeedsCast: false, IsNumericType: false}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user