mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e1abd5ebc1 | ||
|
|
ca4e53969b |
@@ -13,7 +13,7 @@ func TestNormalizeTableAlias(t *testing.T) {
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "strips incorrect alias from simple condition",
|
||||
name: "strips plausible alias from simple condition",
|
||||
query: "APIL.rid_hub = 2576",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
@@ -27,14 +27,14 @@ func TestNormalizeTableAlias(t *testing.T) {
|
||||
want: "apiproviderlink.rid_hub = 2576",
|
||||
},
|
||||
{
|
||||
name: "strips incorrect alias with multiple conditions",
|
||||
name: "strips plausible alias with multiple conditions",
|
||||
query: "APIL.rid_hub = ? AND APIL.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ? AND active = ?",
|
||||
},
|
||||
{
|
||||
name: "handles mixed correct and incorrect aliases",
|
||||
name: "handles mixed correct and plausible aliases",
|
||||
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
|
||||
expectedAlias: "apiproviderlink",
|
||||
tableName: "apiproviderlink",
|
||||
@@ -54,6 +54,20 @@ func TestNormalizeTableAlias(t *testing.T) {
|
||||
tableName: "apiproviderlink",
|
||||
want: "rid_hub = ?",
|
||||
},
|
||||
{
|
||||
name: "keeps reference to different table (not in current table name)",
|
||||
query: "APIL.rid_hub = ?",
|
||||
expectedAlias: "apiprovider",
|
||||
tableName: "apiprovider",
|
||||
want: "APIL.rid_hub = ?",
|
||||
},
|
||||
{
|
||||
name: "keeps reference with short prefix that might be ambiguous",
|
||||
query: "AP.rid = ?",
|
||||
expectedAlias: "apiprovider",
|
||||
tableName: "apiprovider",
|
||||
want: "AP.rid = ?",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -252,17 +252,32 @@ func isOperatorOrKeyword(s string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// isAcronymMatch checks if prefix is an acronym of tableName
|
||||
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
|
||||
func isAcronymMatch(prefix, tableName string) bool {
|
||||
if len(prefix) == 0 || len(tableName) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
prefixIdx := 0
|
||||
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
|
||||
if tableName[i] == prefix[prefixIdx] {
|
||||
prefixIdx++
|
||||
}
|
||||
}
|
||||
|
||||
// All characters of prefix were found in sequence in tableName
|
||||
return prefixIdx == len(prefix)
|
||||
}
|
||||
|
||||
// normalizeTableAlias replaces table alias prefixes in SQL conditions
|
||||
// This handles cases where a user references a table alias that doesn't match
|
||||
// what Bun generates (common in preload contexts)
|
||||
func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
||||
// Pattern: <word>.<column> where <word> might be an incorrect alias
|
||||
// We'll look for patterns like "APIL.column" and either:
|
||||
// 1. Remove the alias prefix entirely (safest)
|
||||
// 2. Replace with the expected alias
|
||||
|
||||
// For now, we'll use a simple approach: if the query contains a dot (qualified reference)
|
||||
// and that prefix is not the expected alias or table name, strip it
|
||||
// 1. Remove the alias prefix if it's clearly meant for this table
|
||||
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
|
||||
|
||||
// Split on spaces and parentheses to find qualified references
|
||||
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||
@@ -277,13 +292,39 @@ func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
||||
column := part[dotIndex+1:]
|
||||
|
||||
// Check if the prefix matches our expected alias or table name (case-insensitive)
|
||||
if !strings.EqualFold(prefix, expectedAlias) &&
|
||||
!strings.EqualFold(prefix, tableName) &&
|
||||
!strings.EqualFold(prefix, strings.ToLower(tableName)) {
|
||||
// This is a different alias - remove the prefix
|
||||
logger.Debug("Stripping incorrect alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
|
||||
if strings.EqualFold(prefix, expectedAlias) ||
|
||||
strings.EqualFold(prefix, tableName) ||
|
||||
strings.EqualFold(prefix, strings.ToLower(tableName)) {
|
||||
// Prefix matches current table, it's safe but redundant - leave it
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the prefix could plausibly be an alias/acronym for this table
|
||||
// Only strip if we're confident it's meant for this table
|
||||
// For example: "APIL" could be an acronym for "apiproviderlink"
|
||||
prefixLower := strings.ToLower(prefix)
|
||||
tableNameLower := strings.ToLower(tableName)
|
||||
|
||||
// Check if prefix is a substring of table name
|
||||
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
|
||||
|
||||
// Check if prefix is an acronym of table name
|
||||
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
|
||||
isAcronym := false
|
||||
if !isSubstring && len(prefixLower) > 2 {
|
||||
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
|
||||
}
|
||||
|
||||
if isSubstring || isAcronym {
|
||||
// This looks like it could be an alias for this table - strip it
|
||||
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
|
||||
// Replace the qualified reference with just the column name
|
||||
modified = strings.ReplaceAll(modified, part, column)
|
||||
} else {
|
||||
// Prefix doesn't match the current table at all
|
||||
// It's likely referring to a different table (JOIN/preload)
|
||||
// DON'T strip it - leave the qualified reference as-is
|
||||
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,17 +78,22 @@ func IsTrivialCondition(cond string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause string to sanitize
|
||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
|
||||
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
|
||||
//
|
||||
// Returns:
|
||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
|
||||
// - An empty string if all conditions were trivial or the input was empty
|
||||
func SanitizeWhereClause(where string, tableName string) string {
|
||||
//
|
||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
}
|
||||
@@ -104,6 +109,22 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
validColumns = getValidColumnsForTable(tableName)
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
conditions := splitByAND(where)
|
||||
|
||||
@@ -124,22 +145,23 @@ func SanitizeWhereClause(where string, tableName string) string {
|
||||
continue
|
||||
}
|
||||
|
||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||
// attempt to add it
|
||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||
// Extract the column name and prefix it
|
||||
columnName := ExtractColumnName(condToCheck)
|
||||
if columnName != "" {
|
||||
// Only prefix if this is a valid column in the model
|
||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
||||
if tableName != "" && hasTablePrefix(condToCheck) {
|
||||
// Extract the current prefix and column name
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace in the original condition (without stripped parens)
|
||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
oldRef := currentPrefix + "." + columnName
|
||||
newRef := tableName + "." + columnName
|
||||
cond = strings.Replace(cond, oldRef, newRef, 1)
|
||||
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
|
||||
} else {
|
||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -288,6 +310,53 @@ func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
return columnMap
|
||||
}
|
||||
|
||||
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
|
||||
// For example: "users.status = 'active'" returns ("users", "status")
|
||||
// Returns empty strings if no table prefix is found
|
||||
func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Common SQL operators to find the column reference
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
var columnRef string
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:idx])
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no operator found, the whole condition might be the column reference
|
||||
if columnRef == "" {
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnRef = parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
if columnRef == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Check if it contains a dot (qualified reference)
|
||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||
table = columnRef[:dotIdx]
|
||||
column = columnRef[dotIdx+1:]
|
||||
|
||||
// Remove quotes from table and column if present
|
||||
table = strings.Trim(table, "`\"'")
|
||||
column = strings.Trim(column, "`\"'")
|
||||
|
||||
return table, column
|
||||
}
|
||||
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
|
||||
@@ -32,29 +32,41 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "valid condition with parentheses",
|
||||
name: "valid condition with parentheses - no prefix added",
|
||||
where: "(status = 'active')",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mixed trivial and valid conditions",
|
||||
name: "mixed trivial and valid conditions - no prefix added",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "condition already with table prefix",
|
||||
name: "condition with correct table prefix - unchanged",
|
||||
where: "users.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions",
|
||||
where: "status = 'active' AND age > 18",
|
||||
name: "condition with incorrect table prefix - fixed",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple conditions with incorrect prefix - fixed",
|
||||
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions without prefix - no prefix added",
|
||||
where: "status = 'active' AND age > 18",
|
||||
tableName: "users",
|
||||
expected: "status = 'active' AND age > 18",
|
||||
},
|
||||
{
|
||||
name: "no table name provided",
|
||||
where: "status = 'active'",
|
||||
@@ -67,6 +79,12 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mixed correct and incorrect prefixes",
|
||||
where: "users.status = 'active' AND wrong_table.age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -159,6 +177,158 @@ func TestIsTrivialCondition(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTableAndColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedTable string
|
||||
expectedCol string
|
||||
}{
|
||||
{
|
||||
name: "qualified column with equals",
|
||||
input: "users.status = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "qualified column with greater than",
|
||||
input: "users.age > 18",
|
||||
expectedTable: "users",
|
||||
expectedCol: "age",
|
||||
},
|
||||
{
|
||||
name: "qualified column with LIKE",
|
||||
input: "users.name LIKE '%john%'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "name",
|
||||
},
|
||||
{
|
||||
name: "qualified column with IN",
|
||||
input: "users.status IN ('active', 'pending')",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "unqualified column",
|
||||
input: "status = 'active'",
|
||||
expectedTable: "",
|
||||
expectedCol: "",
|
||||
},
|
||||
{
|
||||
name: "qualified with backticks",
|
||||
input: "`users`.`status` = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "schema.table.column reference",
|
||||
input: "public.users.status = 'active'",
|
||||
expectedTable: "public.users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectedTable: "",
|
||||
expectedCol: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
table, col := extractTableAndColumn(tt.input)
|
||||
if table != tt.expectedTable || col != tt.expectedCol {
|
||||
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
|
||||
tt.input, table, col, tt.expectedTable, tt.expectedCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
options *RequestOptions
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "preload relation prefix is preserved",
|
||||
where: "Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "Department.name = 'Engineering'",
|
||||
},
|
||||
{
|
||||
name: "multiple preload relations - all preserved",
|
||||
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
{Relation: "Manager"},
|
||||
},
|
||||
},
|
||||
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mix of main table and preload relation",
|
||||
where: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
},
|
||||
{
|
||||
name: "incorrect prefix fixed when not a preload relation",
|
||||
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{Relation: "Department"},
|
||||
},
|
||||
},
|
||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
},
|
||||
{
|
||||
name: "no options provided - works as before",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
options: nil,
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "empty preload list - works as before",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "users",
|
||||
options: &RequestOptions{Preload: []PreloadOption{}},
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var result string
|
||||
if tt.options != nil {
|
||||
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||
} else {
|
||||
result = SanitizeWhereClause(tt.where, tt.tableName)
|
||||
}
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test model for model-aware sanitization tests
|
||||
type MasterTask struct {
|
||||
ID int `bun:"id,pk"`
|
||||
@@ -182,34 +352,52 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid column gets prefixed",
|
||||
name: "valid column without prefix - no prefix added",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns without prefix - no prefix added",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active' AND user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "incorrect table prefix on valid column - fixed",
|
||||
where: "wrong_table.status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns get prefixed",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
name: "incorrect prefix on invalid column - not fixed",
|
||||
where: "wrong_table.invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "invalid column does not get prefixed",
|
||||
where: "invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "invalid_column = 'value'",
|
||||
expected: "wrong_table.invalid_column = 'value'",
|
||||
},
|
||||
{
|
||||
name: "mix of valid and trivial conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column - no prefix added",
|
||||
where: "(status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "correct prefix - unchanged",
|
||||
where: "mastertask.status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column",
|
||||
where: "(status = 'active')",
|
||||
name: "multiple conditions with mixed prefixes",
|
||||
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
331
pkg/reflection/generic_model_test.go
Normal file
331
pkg/reflection/generic_model_test.go
Normal file
@@ -0,0 +1,331 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for GetModelColumnDetail
|
||||
type TestModelForColumnDetail struct {
|
||||
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
|
||||
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
|
||||
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
|
||||
Description string `gorm:"column:description;type:text;null" json:"description"`
|
||||
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
|
||||
}
|
||||
|
||||
type EmbeddedBase struct {
|
||||
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
|
||||
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
|
||||
}
|
||||
|
||||
type ModelWithEmbeddedForDetail struct {
|
||||
EmbeddedBase
|
||||
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
|
||||
Content string `gorm:"column:content;type:text" json:"content"`
|
||||
}
|
||||
|
||||
// Model with nil embedded pointer
|
||||
type ModelWithNilEmbedded struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
*EmbeddedBase
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
}
|
||||
|
||||
func TestGetModelColumnDetail(t *testing.T) {
|
||||
t.Run("simple struct", func(t *testing.T) {
|
||||
model := TestModelForColumnDetail{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
Email: "test@example.com",
|
||||
Description: "Test description",
|
||||
ForeignKey: 100,
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) != 5 {
|
||||
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||
}
|
||||
|
||||
// Check ID field
|
||||
found := false
|
||||
for _, detail := range details {
|
||||
if detail.Name == "ID" {
|
||||
found = true
|
||||
if detail.SQLName != "rid_test" {
|
||||
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
|
||||
}
|
||||
// Note: primaryKey (without underscore) is not detected as primary_key
|
||||
// The function looks for "identity" or "primary_key" (with underscore)
|
||||
if detail.SQLDataType != "bigserial" {
|
||||
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
|
||||
}
|
||||
if detail.Nullable {
|
||||
t.Errorf("Expected Nullable false, got true")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("ID field not found in details")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("struct with embedded fields", func(t *testing.T) {
|
||||
model := ModelWithEmbeddedForDetail{
|
||||
EmbeddedBase: EmbeddedBase{
|
||||
ID: 1,
|
||||
CreatedAt: "2024-01-01",
|
||||
},
|
||||
Title: "Test Title",
|
||||
Content: "Test Content",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
|
||||
if len(details) != 4 {
|
||||
t.Errorf("Expected 4 fields, got %d", len(details))
|
||||
}
|
||||
|
||||
// Check that embedded field is included
|
||||
foundID := false
|
||||
foundCreatedAt := false
|
||||
for _, detail := range details {
|
||||
if detail.Name == "ID" {
|
||||
foundID = true
|
||||
if detail.SQLKey != "primary_key" {
|
||||
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
|
||||
}
|
||||
}
|
||||
if detail.Name == "CreatedAt" {
|
||||
foundCreatedAt = true
|
||||
}
|
||||
}
|
||||
if !foundID {
|
||||
t.Errorf("Embedded ID field not found")
|
||||
}
|
||||
if !foundCreatedAt {
|
||||
t.Errorf("Embedded CreatedAt field not found")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
|
||||
model := ModelWithNilEmbedded{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
EmbeddedBase: nil, // nil embedded pointer
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
|
||||
if len(details) != 2 {
|
||||
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("pointer to struct", func(t *testing.T) {
|
||||
model := &TestModelForColumnDetail{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) != 5 {
|
||||
t.Errorf("Expected 5 fields, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid value", func(t *testing.T) {
|
||||
var invalid reflect.Value
|
||||
details := GetModelColumnDetail(invalid)
|
||||
|
||||
if len(details) != 0 {
|
||||
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-struct type", func(t *testing.T) {
|
||||
details := GetModelColumnDetail(reflect.ValueOf(123))
|
||||
|
||||
if len(details) != 0 {
|
||||
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nullable and not null detection", func(t *testing.T) {
|
||||
model := TestModelForColumnDetail{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.Nullable {
|
||||
t.Errorf("ID should not be nullable (has 'not null')")
|
||||
}
|
||||
case "Name":
|
||||
if detail.Nullable {
|
||||
t.Errorf("Name should not be nullable (has 'not null')")
|
||||
}
|
||||
case "Email":
|
||||
if !detail.Nullable {
|
||||
t.Errorf("Email should be nullable (has 'nullable')")
|
||||
}
|
||||
case "Description":
|
||||
if !detail.Nullable {
|
||||
t.Errorf("Description should be nullable (has 'null')")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unique and uniqueindex detection", func(t *testing.T) {
|
||||
type UniqueTestModel struct {
|
||||
ID int `gorm:"column:id;primary_key"`
|
||||
Username string `gorm:"column:username;unique"`
|
||||
Email string `gorm:"column:email;uniqueindex"`
|
||||
}
|
||||
|
||||
model := UniqueTestModel{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.SQLKey != "primary_key" {
|
||||
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
|
||||
}
|
||||
case "Username":
|
||||
if detail.SQLKey != "unique" {
|
||||
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
|
||||
}
|
||||
case "Email":
|
||||
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
|
||||
// This is expected behavior based on the code logic
|
||||
if detail.SQLKey != "unique" {
|
||||
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("foreign key detection", func(t *testing.T) {
|
||||
// Note: The foreignkey extraction in generic_model.go has a bug where
|
||||
// it requires ik > 0, so foreignkey at the start won't extract the value
|
||||
type FKTestModel struct {
|
||||
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
|
||||
}
|
||||
|
||||
model := FKTestModel{}
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
if len(details) == 0 {
|
||||
t.Fatal("Expected at least 1 field")
|
||||
}
|
||||
|
||||
detail := details[0]
|
||||
if detail.SQLKey != "foreign_key" {
|
||||
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
|
||||
}
|
||||
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
|
||||
// when foreignkey is not at the beginning of the string
|
||||
if detail.SQLName != "rid_parent" {
|
||||
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFnFindKeyVal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
src string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "find column",
|
||||
src: "column:user_id;primaryKey;type:bigint",
|
||||
key: "column:",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "find type",
|
||||
src: "column:name;type:varchar(255);not null",
|
||||
key: "type:",
|
||||
expected: "varchar(255)",
|
||||
},
|
||||
{
|
||||
name: "key not found",
|
||||
src: "primaryKey;autoIncrement",
|
||||
key: "column:",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "key at end without semicolon",
|
||||
src: "primaryKey;column:id",
|
||||
key: "column:",
|
||||
expected: "id",
|
||||
},
|
||||
{
|
||||
name: "case insensitive search",
|
||||
src: "Column:user_id;primaryKey",
|
||||
key: "column:",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "empty src",
|
||||
src: "",
|
||||
key: "column:",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple occurrences (returns first)",
|
||||
src: "column:first;column:second",
|
||||
key: "column:",
|
||||
expected: "first",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := fnFindKeyVal(tt.src, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
|
||||
model := TestModelForColumnDetail{
|
||||
ID: 123,
|
||||
Name: "TestName",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
|
||||
details := GetModelColumnDetail(reflect.ValueOf(model))
|
||||
|
||||
for _, detail := range details {
|
||||
if !detail.FieldValue.IsValid() {
|
||||
t.Errorf("Field %s has invalid FieldValue", detail.Name)
|
||||
}
|
||||
|
||||
// Check that FieldValue matches the actual value
|
||||
switch detail.Name {
|
||||
case "ID":
|
||||
if detail.FieldValue.Int() != 123 {
|
||||
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
|
||||
}
|
||||
case "Name":
|
||||
if detail.FieldValue.String() != "TestName" {
|
||||
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
|
||||
}
|
||||
case "Email":
|
||||
if detail.FieldValue.String() != "test@example.com" {
|
||||
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -316,7 +316,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
@@ -1351,7 +1351,9 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
}
|
||||
|
||||
if len(preload.Where) > 0 {
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: preloads}
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
|
||||
@@ -450,7 +450,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Apply the preload with recursive support
|
||||
query = h.applyPreloadWithRecursion(query, preload, model, 0)
|
||||
query = h.applyPreloadWithRecursion(query, preload, options.Preload, model, 0)
|
||||
}
|
||||
|
||||
// Apply DISTINCT if requested
|
||||
@@ -480,8 +480,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
if options.CustomSQLWhere != "" {
|
||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -490,8 +490,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply custom SQL WHERE clause (OR condition)
|
||||
if options.CustomSQLOr != "" {
|
||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
@@ -625,7 +625,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
@@ -703,7 +703,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, allPreloads []common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||
// Log relationship keys if they're specified (from XFiles)
|
||||
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||
@@ -799,7 +799,9 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
|
||||
// Apply WHERE clause
|
||||
if len(preload.Where) > 0 {
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -832,7 +834,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||
|
||||
// Recursively apply preload until we reach depth 5
|
||||
query = h.applyPreloadWithRecursion(query, recursivePreload, model, depth+1)
|
||||
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
||||
}
|
||||
|
||||
return query
|
||||
|
||||
Reference in New Issue
Block a user