mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-31 17:28:58 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aa99e8e4bc | ||
|
|
163593901f | ||
|
|
1261960e97 | ||
|
|
76bbf33db2 | ||
|
|
02c9b96b0c | ||
|
|
9a3564f05f | ||
|
|
a931b8cdd2 |
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -105,6 +106,14 @@ type BunSelectQuery struct {
|
|||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
tableAlias string
|
tableAlias string
|
||||||
|
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||||
|
}
|
||||||
|
|
||||||
|
// deferredPreload represents a preload that will be executed as a separate query
|
||||||
|
// to avoid PostgreSQL identifier length limits
|
||||||
|
type deferredPreload struct {
|
||||||
|
relation string
|
||||||
|
apply []func(common.SelectQuery) common.SelectQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
@@ -233,7 +242,92 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
|
||||||
|
// // when combined with typical column names
|
||||||
|
// func shortenAliasForPostgres(relationPath string) (string, bool) {
|
||||||
|
// // Convert relation path to the alias format Bun uses: dots become double underscores
|
||||||
|
// // Also convert to lowercase and use snake_case as Bun does
|
||||||
|
// parts := strings.Split(relationPath, ".")
|
||||||
|
// alias := strings.ToLower(strings.Join(parts, "__"))
|
||||||
|
|
||||||
|
// // PostgreSQL truncates identifiers to 63 chars
|
||||||
|
// // If the alias + typical column name would exceed this, we need to shorten
|
||||||
|
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
|
||||||
|
// const maxAliasLength = 30
|
||||||
|
|
||||||
|
// if len(alias) > maxAliasLength {
|
||||||
|
// // Create a shortened alias using a hash of the original
|
||||||
|
// hash := md5.Sum([]byte(alias))
|
||||||
|
// hashStr := hex.EncodeToString(hash[:])[:8]
|
||||||
|
|
||||||
|
// // Keep first few chars of original for readability + hash
|
||||||
|
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
|
||||||
|
// if prefixLen > len(alias) {
|
||||||
|
// prefixLen = len(alias)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// shortened := alias[:prefixLen] + "_" + hashStr
|
||||||
|
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
|
||||||
|
// alias, len(alias), shortened, len(shortened))
|
||||||
|
// return shortened, true
|
||||||
|
// }
|
||||||
|
|
||||||
|
// return alias, false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
|
||||||
|
// // Bun creates aliases like: relationChain__columnName
|
||||||
|
// func estimateColumnAliasLength(relationPath string, columnName string) int {
|
||||||
|
// relationParts := strings.Split(relationPath, ".")
|
||||||
|
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||||
|
// // Bun adds "__" between alias and column name
|
||||||
|
// return len(aliasChain) + 2 + len(columnName)
|
||||||
|
// }
|
||||||
|
|
||||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||||
|
// Check if this relation chain would create problematic long aliases
|
||||||
|
relationParts := strings.Split(relation, ".")
|
||||||
|
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||||
|
|
||||||
|
// PostgreSQL's identifier limit is 63 characters
|
||||||
|
const postgresIdentifierLimit = 63
|
||||||
|
const safeAliasLimit = 35 // Leave room for column names
|
||||||
|
|
||||||
|
// If the alias chain is too long, defer this preload to be executed as a separate query
|
||||||
|
if len(aliasChain) > safeAliasLimit {
|
||||||
|
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
|
||||||
|
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
|
||||||
|
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
|
||||||
|
|
||||||
|
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
|
||||||
|
// This avoids the long concatenated alias
|
||||||
|
if len(relationParts) > 1 {
|
||||||
|
// Load first level normally: "Parent"
|
||||||
|
firstLevel := relationParts[0]
|
||||||
|
remainingPath := strings.Join(relationParts[1:], ".")
|
||||||
|
|
||||||
|
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
|
||||||
|
firstLevel, remainingPath)
|
||||||
|
|
||||||
|
// Apply the first level preload normally
|
||||||
|
b.query = b.query.Relation(firstLevel)
|
||||||
|
|
||||||
|
// Store the remaining nested preload to be executed after the main query
|
||||||
|
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
|
||||||
|
relation: relation,
|
||||||
|
apply: apply,
|
||||||
|
})
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single level but still too long - just warn and continue
|
||||||
|
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
|
||||||
|
"Consider renaming the field to avoid potential issues.",
|
||||||
|
relation, len(aliasChain))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normal preload handling
|
||||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -309,7 +403,23 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
|||||||
if dest == nil {
|
if dest == nil {
|
||||||
return fmt.Errorf("destination cannot be nil")
|
return fmt.Errorf("destination cannot be nil")
|
||||||
}
|
}
|
||||||
return b.query.Scan(ctx, dest)
|
|
||||||
|
// Execute the main query first
|
||||||
|
err = b.query.Scan(ctx, dest)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute any deferred preloads
|
||||||
|
if len(b.deferredPreloads) > 0 {
|
||||||
|
err = b.executeDeferredPreloads(ctx, dest)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||||
|
// Don't fail the whole query, just log the warning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
@@ -322,7 +432,132 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
return fmt.Errorf("model is nil")
|
return fmt.Errorf("model is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
return b.query.Scan(ctx)
|
// Execute the main query first
|
||||||
|
err = b.query.Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute any deferred preloads
|
||||||
|
if len(b.deferredPreloads) > 0 {
|
||||||
|
model := b.query.GetModel()
|
||||||
|
err = b.executeDeferredPreloads(ctx, model.Value())
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||||
|
// Don't fail the whole query, just log the warning
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
|
||||||
|
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
|
||||||
|
if len(b.deferredPreloads) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dp := range b.deferredPreloads {
|
||||||
|
err := b.executeSingleDeferredPreload(ctx, dest, dp)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeSingleDeferredPreload executes a single deferred preload
|
||||||
|
// For a relation like "Parent.Child", it:
|
||||||
|
// 1. Finds all loaded Parent records in dest
|
||||||
|
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
|
||||||
|
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
|
||||||
|
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
|
||||||
|
relationParts := strings.Split(dp.relation, ".")
|
||||||
|
if len(relationParts) < 2 {
|
||||||
|
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The parent relation that was already loaded
|
||||||
|
parentRelation := relationParts[0]
|
||||||
|
// The child relation we need to load
|
||||||
|
childRelation := strings.Join(relationParts[1:], ".")
|
||||||
|
|
||||||
|
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
|
||||||
|
|
||||||
|
// Use reflection to access the parent relation field(s) in the loaded records
|
||||||
|
// Then load the child relation for those parent records
|
||||||
|
destValue := reflect.ValueOf(dest)
|
||||||
|
if destValue.Kind() == reflect.Ptr {
|
||||||
|
destValue = destValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle both slice and single record
|
||||||
|
if destValue.Kind() == reflect.Slice {
|
||||||
|
// Iterate through each record in the slice
|
||||||
|
for i := 0; i < destValue.Len(); i++ {
|
||||||
|
record := destValue.Index(i)
|
||||||
|
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
|
||||||
|
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
|
||||||
|
// Continue with other records
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Single record
|
||||||
|
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
|
||||||
|
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadChildRelationForRecord loads a child relation for a single parent record
|
||||||
|
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
|
||||||
|
// Ensure we're working with the actual struct value, not a pointer
|
||||||
|
if record.Kind() == reflect.Ptr {
|
||||||
|
record = record.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the parent relation field
|
||||||
|
parentField := record.FieldByName(parentRelation)
|
||||||
|
if !parentField.IsValid() {
|
||||||
|
// Parent relation field doesn't exist
|
||||||
|
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the parent field is nil (for pointer fields)
|
||||||
|
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
|
||||||
|
// Parent relation not loaded or nil, skip
|
||||||
|
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the interface value to pass to Bun
|
||||||
|
parentValue := parentField.Interface()
|
||||||
|
|
||||||
|
// Load the child relation on the parent record
|
||||||
|
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
|
||||||
|
return b.db.NewSelect().
|
||||||
|
Model(parentValue).
|
||||||
|
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||||
|
// Apply any custom query modifications
|
||||||
|
if len(apply) > 0 {
|
||||||
|
wrapper := &BunSelectQuery{query: sq, db: b.db}
|
||||||
|
current := common.SelectQuery(wrapper)
|
||||||
|
for _, fn := range apply {
|
||||||
|
if fn != nil {
|
||||||
|
current = fn(current)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||||
|
return finalBun.query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sq
|
||||||
|
}).
|
||||||
|
Scan(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
import "context"
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
// Database interface designed to work with both GORM and Bun
|
// Database interface designed to work with both GORM and Bun
|
||||||
type Database interface {
|
type Database interface {
|
||||||
@@ -130,6 +135,99 @@ type ResponseWriter interface {
|
|||||||
// HTTPHandlerFunc type for HTTP handlers
|
// HTTPHandlerFunc type for HTTP handlers
|
||||||
type HTTPHandlerFunc func(ResponseWriter, Request)
|
type HTTPHandlerFunc func(ResponseWriter, Request)
|
||||||
|
|
||||||
|
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
|
||||||
|
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
|
||||||
|
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
|
||||||
|
type StandardResponseWriter struct {
|
||||||
|
w http.ResponseWriter
|
||||||
|
status int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardResponseWriter) SetHeader(key, value string) {
|
||||||
|
s.w.Header().Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
|
||||||
|
s.status = statusCode
|
||||||
|
s.w.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
||||||
|
return s.w.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||||
|
s.SetHeader("Content-Type", "application/json")
|
||||||
|
return json.NewEncoder(s.w).Encode(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StandardRequest adapts *http.Request to Request interface
|
||||||
|
type StandardRequest struct {
|
||||||
|
r *http.Request
|
||||||
|
body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) Method() string {
|
||||||
|
return s.r.Method
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) URL() string {
|
||||||
|
return s.r.URL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) Header(key string) string {
|
||||||
|
return s.r.Header.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) AllHeaders() map[string]string {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
for key, values := range s.r.Header {
|
||||||
|
if len(values) > 0 {
|
||||||
|
headers[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) Body() ([]byte, error) {
|
||||||
|
if s.body != nil {
|
||||||
|
return s.body, nil
|
||||||
|
}
|
||||||
|
if s.r.Body == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
defer s.r.Body.Close()
|
||||||
|
body, err := io.ReadAll(s.r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.body = body
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) PathParam(key string) string {
|
||||||
|
// Standard http.Request doesn't have path params
|
||||||
|
// This should be set by the router
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) QueryParam(key string) string {
|
||||||
|
return s.r.URL.Query().Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StandardRequest) AllQueryParams() map[string]string {
|
||||||
|
params := make(map[string]string)
|
||||||
|
for key, values := range s.r.URL.Query() {
|
||||||
|
if len(values) > 0 {
|
||||||
|
params[key] = values[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
// TableNameProvider interface for models that provide table names
|
// TableNameProvider interface for models that provide table names
|
||||||
type TableNameProvider interface {
|
type TableNameProvider interface {
|
||||||
TableName() string
|
TableName() string
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||||
@@ -135,6 +137,15 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
|
|
||||||
where = strings.TrimSpace(where)
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
|
// Strip outer parentheses and re-trim
|
||||||
|
where = stripOuterParentheses(where)
|
||||||
|
|
||||||
|
// Get valid columns from the model if tableName is provided
|
||||||
|
var validColumns map[string]bool
|
||||||
|
if tableName != "" {
|
||||||
|
validColumns = getValidColumnsForTable(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
// Split by AND to handle multiple conditions
|
// Split by AND to handle multiple conditions
|
||||||
conditions := splitByAND(where)
|
conditions := splitByAND(where)
|
||||||
|
|
||||||
@@ -146,22 +157,32 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strip parentheses from the condition before checking
|
||||||
|
condToCheck := stripOuterParentheses(cond)
|
||||||
|
|
||||||
// Skip trivial conditions that always evaluate to true
|
// Skip trivial conditions that always evaluate to true
|
||||||
if IsTrivialCondition(cond) {
|
if IsTrivialCondition(condToCheck) {
|
||||||
logger.Debug("Removing trivial condition: '%s'", cond)
|
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||||
// attempt to add it
|
// attempt to add it
|
||||||
if tableName != "" && !hasTablePrefix(cond) {
|
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||||
if !IsSQLExpression(strings.ToLower(cond)) {
|
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||||
// Extract the column name and prefix it
|
// Extract the column name and prefix it
|
||||||
columnName := ExtractColumnName(cond)
|
columnName := ExtractColumnName(condToCheck)
|
||||||
if columnName != "" {
|
if columnName != "" {
|
||||||
|
// Only prefix if this is a valid column in the model
|
||||||
|
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||||
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
|
// Replace in the original condition (without stripped parens)
|
||||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -182,6 +203,43 @@ func SanitizeWhereClause(where string, tableName string) string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stripOuterParentheses removes matching outer parentheses from a string
|
||||||
|
// It handles nested parentheses correctly
|
||||||
|
func stripOuterParentheses(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
for {
|
||||||
|
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||||
|
depth := 0
|
||||||
|
matched := false
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
|
case '(':
|
||||||
|
depth++
|
||||||
|
case ')':
|
||||||
|
depth--
|
||||||
|
if depth == 0 && i == len(s)-1 {
|
||||||
|
matched = true
|
||||||
|
} else if depth == 0 {
|
||||||
|
// Found a closing paren before the end, so outer parens don't match
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matched {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the outer parentheses and continue
|
||||||
|
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||||
func splitByAND(where string) []string {
|
func splitByAND(where string) []string {
|
||||||
@@ -245,3 +303,38 @@ func IsSQLKeyword(word string) bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
||||||
|
// Returns a map of column names for fast lookup, or nil if the model is not found
|
||||||
|
func getValidColumnsForTable(tableName string) map[string]bool {
|
||||||
|
// Try to get the model from the registry
|
||||||
|
model, err := modelregistry.GetModelByName(tableName)
|
||||||
|
if err != nil {
|
||||||
|
// Model not found, return nil to indicate we should use fallback behavior
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get SQL columns from the model
|
||||||
|
columns := reflection.GetSQLModelColumns(model)
|
||||||
|
if len(columns) == 0 {
|
||||||
|
// No columns found, return nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map for fast lookup
|
||||||
|
columnMap := make(map[string]bool, len(columns))
|
||||||
|
for _, col := range columns {
|
||||||
|
columnMap[strings.ToLower(col)] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return columnMap
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidColumn checks if a column name exists in the valid columns map
|
||||||
|
// Handles case-insensitive comparison
|
||||||
|
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||||
|
if validColumns == nil {
|
||||||
|
return true // No model info, assume valid
|
||||||
|
}
|
||||||
|
return validColumns[strings.ToLower(columnName)]
|
||||||
|
}
|
||||||
|
|||||||
224
pkg/common/sql_helpers_test.go
Normal file
224
pkg/common/sql_helpers_test.go
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizeWhereClause(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "trivial conditions in parentheses",
|
||||||
|
where: "(true AND true AND true)",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trivial conditions without parentheses",
|
||||||
|
where: "true AND true AND true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single trivial condition",
|
||||||
|
where: "true",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid condition with parentheses",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed trivial and valid conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "condition already with table prefix",
|
||||||
|
where: "users.status = 'active'",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid conditions",
|
||||||
|
where: "status = 'active' AND age > 18",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no table name provided",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "",
|
||||||
|
expected: "status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty where clause",
|
||||||
|
where: "",
|
||||||
|
tableName: "users",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripOuterParentheses(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single level parentheses",
|
||||||
|
input: "(true)",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple levels",
|
||||||
|
input: "((true))",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no parentheses",
|
||||||
|
input: "true",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mismatched parentheses",
|
||||||
|
input: "(true",
|
||||||
|
expected: "(true",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex expression",
|
||||||
|
input: "(a AND b)",
|
||||||
|
expected: "a AND b",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested but not outer",
|
||||||
|
input: "(a AND (b OR c)) AND d",
|
||||||
|
expected: "(a AND (b OR c)) AND d",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with spaces",
|
||||||
|
input: " ( true ) ",
|
||||||
|
expected: "true",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := stripOuterParentheses(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTrivialCondition(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"true", "true", true},
|
||||||
|
{"true with spaces", " true ", true},
|
||||||
|
{"TRUE uppercase", "TRUE", true},
|
||||||
|
{"1=1", "1=1", true},
|
||||||
|
{"1 = 1", "1 = 1", true},
|
||||||
|
{"true = true", "true = true", true},
|
||||||
|
{"valid condition", "status = 'active'", false},
|
||||||
|
{"false", "false", false},
|
||||||
|
{"column name", "is_active", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsTrivialCondition(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test model for model-aware sanitization tests
|
||||||
|
type MasterTask struct {
|
||||||
|
ID int `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
Status string `bun:"status"`
|
||||||
|
UserID int `bun:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||||
|
// Register the test model
|
||||||
|
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||||
|
if err != nil {
|
||||||
|
// Model might already be registered, ignore error
|
||||||
|
t.Logf("Model registration returned: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
where string
|
||||||
|
tableName string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid column gets prefixed",
|
||||||
|
where: "status = 'active'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple valid columns get prefixed",
|
||||||
|
where: "status = 'active' AND user_id = 123",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid column does not get prefixed",
|
||||||
|
where: "invalid_column = 'value'",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "invalid_column = 'value'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mix of valid and trivial conditions",
|
||||||
|
where: "true AND status = 'active' AND 1=1",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parentheses with valid column",
|
||||||
|
where: "(status = 'active')",
|
||||||
|
tableName: "mastertask",
|
||||||
|
expected: "mastertask.status = 'active'",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,6 +43,11 @@ type PreloadOption struct {
|
|||||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||||
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||||
|
|
||||||
|
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||||
|
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||||
|
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||||
|
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||||
}
|
}
|
||||||
|
|
||||||
type FilterOption struct {
|
type FilterOption struct {
|
||||||
|
|||||||
@@ -17,3 +17,33 @@ func Len(v any) int {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
|
||||||
|
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
|
||||||
|
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
|
||||||
|
// dots, it returns everything after the last dot up to the first delimiter.
|
||||||
|
func ExtractTableNameOnly(fullName string) string {
|
||||||
|
// First, split by dot to remove schema prefix if present
|
||||||
|
lastDotIndex := -1
|
||||||
|
for i, char := range fullName {
|
||||||
|
if char == '.' {
|
||||||
|
lastDotIndex = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start from after the last dot (or from beginning if no dot)
|
||||||
|
startIndex := 0
|
||||||
|
if lastDotIndex != -1 {
|
||||||
|
startIndex = lastDotIndex + 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now find the end (first delimiter after the table name)
|
||||||
|
for i := startIndex; i < len(fullName); i++ {
|
||||||
|
char := rune(fullName[i])
|
||||||
|
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
|
||||||
|
return fullName[startIndex:i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fullName[startIndex:]
|
||||||
|
}
|
||||||
|
|||||||
@@ -756,11 +756,42 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
|
|||||||
// 2. Bun tag name (if exists)
|
// 2. Bun tag name (if exists)
|
||||||
// 3. Gorm tag name (if exists)
|
// 3. Gorm tag name (if exists)
|
||||||
// 4. JSON tag name (if exists)
|
// 4. JSON tag name (if exists)
|
||||||
|
//
|
||||||
|
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
|
||||||
|
// For nested fields, it traverses through each level of the struct hierarchy
|
||||||
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||||
if model == nil || fieldName == "" {
|
if model == nil || fieldName == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split the field name by "." to handle nested/recursive relations
|
||||||
|
fieldParts := strings.Split(fieldName, ".")
|
||||||
|
|
||||||
|
// Start with the current model
|
||||||
|
currentModel := model
|
||||||
|
|
||||||
|
// Traverse through each level of the field path
|
||||||
|
for _, part := range fieldParts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
currentModel = getRelationModelSingleLevel(currentModel, part)
|
||||||
|
if currentModel == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return currentModel
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||||
|
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||||
|
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||||
|
if model == nil || fieldName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
if modelType == nil {
|
if modelType == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -199,7 +199,9 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
query = query.Column(options.Columns...)
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(options.ComputedColumns) > 0 {
|
if len(options.ComputedColumns) > 0 {
|
||||||
@@ -1209,7 +1211,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -213,6 +213,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
tableName := GetTableName(ctx)
|
tableName := GetTableName(ctx)
|
||||||
model := GetModel(ctx)
|
model := GetModel(ctx)
|
||||||
|
|
||||||
|
if id == "" {
|
||||||
|
options.SingleRecordAsObject = false
|
||||||
|
}
|
||||||
|
|
||||||
// Execute BeforeRead hooks
|
// Execute BeforeRead hooks
|
||||||
hookCtx := &HookContext{
|
hookCtx := &HookContext{
|
||||||
Context: ctx,
|
Context: ctx,
|
||||||
@@ -299,7 +303,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
query = query.Column(options.Columns...)
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply expand (Just expand to Preload for now)
|
// Apply expand (Just expand to Preload for now)
|
||||||
@@ -392,7 +399,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "")
|
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||||
if sanitizedWhere != "" {
|
if sanitizedWhere != "" {
|
||||||
query = query.Where(sanitizedWhere)
|
query = query.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
@@ -402,7 +409,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if options.CustomSQLOr != "" {
|
if options.CustomSQLOr != "" {
|
||||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "")
|
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||||
if sanitizedOr != "" {
|
if sanitizedOr != "" {
|
||||||
query = query.WhereOr(sanitizedOr)
|
query = query.WhereOr(sanitizedOr)
|
||||||
}
|
}
|
||||||
@@ -481,7 +488,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Apply cursor filter to query
|
// Apply cursor filter to query
|
||||||
if cursorFilter != "" {
|
if cursorFilter != "" {
|
||||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "")
|
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||||
if sanitizedCursor != "" {
|
if sanitizedCursor != "" {
|
||||||
query = query.Where(sanitizedCursor)
|
query = query.Where(sanitizedCursor)
|
||||||
}
|
}
|
||||||
@@ -560,11 +567,33 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||||
|
// Log relationship keys if they're specified (from XFiles)
|
||||||
|
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||||
|
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||||
|
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
|
||||||
|
|
||||||
|
// Build a WHERE clause using the relationship keys if needed
|
||||||
|
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
|
||||||
|
// However, if the relationship keys are explicitly provided from XFiles, we can use them
|
||||||
|
// to add additional filtering or validation
|
||||||
|
if preload.RelatedKey != "" && preload.Where == "" {
|
||||||
|
// For child tables: ensure the child's relatedkey column will be matched
|
||||||
|
// The actual parent value is dynamic and handled by Bun's preload mechanism
|
||||||
|
// We just log this for visibility
|
||||||
|
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
|
||||||
|
preload.Relation, preload.RelatedKey)
|
||||||
|
}
|
||||||
|
if preload.ForeignKey != "" && preload.Where == "" {
|
||||||
|
// For parent tables: ensure the parent's primary key matches the current table's foreign key
|
||||||
|
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
|
||||||
|
preload.Relation, preload.ForeignKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply the preload
|
// Apply the preload
|
||||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
// Get the related model for column operations
|
// Get the related model for column operations
|
||||||
relationParts := strings.Split(preload.Relation, ",")
|
relatedModel := reflection.GetRelationModel(model, preload.Relation)
|
||||||
relatedModel := reflection.GetRelationModel(model, relationParts[0])
|
|
||||||
if relatedModel == nil {
|
if relatedModel == nil {
|
||||||
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
||||||
// relatedModel = model // fallback to parent model
|
// relatedModel = model // fallback to parent model
|
||||||
@@ -633,7 +662,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply WHERE clause
|
// Apply WHERE clause
|
||||||
if len(preload.Where) > 0 {
|
if len(preload.Where) > 0 {
|
||||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||||
if len(sanitizedWhere) > 0 {
|
if len(sanitizedWhere) > 0 {
|
||||||
sq = sq.Where(sanitizedWhere)
|
sq = sq.Where(sanitizedWhere)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,9 +162,17 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
case strings.HasPrefix(key, "x-searchcols"):
|
case strings.HasPrefix(key, "x-searchcols"):
|
||||||
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
options.SearchColumns = h.parseCommaSeparated(decodedValue)
|
||||||
case strings.HasPrefix(key, "x-custom-sql-w"):
|
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||||
|
if options.CustomSQLWhere != "" {
|
||||||
|
options.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", options.CustomSQLWhere, decodedValue)
|
||||||
|
} else {
|
||||||
options.CustomSQLWhere = decodedValue
|
options.CustomSQLWhere = decodedValue
|
||||||
|
}
|
||||||
case strings.HasPrefix(key, "x-custom-sql-or"):
|
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||||
|
if options.CustomSQLOr != "" {
|
||||||
|
options.CustomSQLOr = fmt.Sprintf("%s OR (%s)", options.CustomSQLOr, decodedValue)
|
||||||
|
} else {
|
||||||
options.CustomSQLOr = decodedValue
|
options.CustomSQLOr = decodedValue
|
||||||
|
}
|
||||||
|
|
||||||
// Joins & Relations
|
// Joins & Relations
|
||||||
case strings.HasPrefix(key, "x-preload"):
|
case strings.HasPrefix(key, "x-preload"):
|
||||||
@@ -226,6 +234,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
case strings.HasPrefix(key, "x-cql-sel-"):
|
case strings.HasPrefix(key, "x-cql-sel-"):
|
||||||
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
colName := strings.TrimPrefix(key, "x-cql-sel-")
|
||||||
options.ComputedQL[colName] = decodedValue
|
options.ComputedQL[colName] = decodedValue
|
||||||
|
|
||||||
case strings.HasPrefix(key, "x-distinct"):
|
case strings.HasPrefix(key, "x-distinct"):
|
||||||
options.Distinct = strings.EqualFold(decodedValue, "true")
|
options.Distinct = strings.EqualFold(decodedValue, "true")
|
||||||
case strings.HasPrefix(key, "x-skipcount"):
|
case strings.HasPrefix(key, "x-skipcount"):
|
||||||
@@ -267,6 +276,12 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
h.resolveRelationNamesInOptions(&options, model)
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always sort according to the primary key if no sorting is specified
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
return options
|
return options
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -777,7 +792,7 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
|||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
if field.Name == nameOrTable {
|
if field.Name == nameOrTable {
|
||||||
// It's already a field name
|
// It's already a field name
|
||||||
logger.Debug("Input '%s' is a field name", nameOrTable)
|
// logger.Debug("Input '%s' is a field name", nameOrTable)
|
||||||
return nameOrTable
|
return nameOrTable
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -919,6 +934,20 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
// Set recursive flag
|
// Set recursive flag
|
||||||
preloadOpt.Recursive = xfile.Recursive
|
preloadOpt.Recursive = xfile.Recursive
|
||||||
|
|
||||||
|
// Extract relationship keys for proper foreign key filtering
|
||||||
|
if xfile.PrimaryKey != "" {
|
||||||
|
preloadOpt.PrimaryKey = xfile.PrimaryKey
|
||||||
|
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
|
||||||
|
}
|
||||||
|
if xfile.RelatedKey != "" {
|
||||||
|
preloadOpt.RelatedKey = xfile.RelatedKey
|
||||||
|
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
|
||||||
|
}
|
||||||
|
if xfile.ForeignKey != "" {
|
||||||
|
preloadOpt.ForeignKey = xfile.ForeignKey
|
||||||
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
|
}
|
||||||
|
|
||||||
// Add the preload option
|
// Add the preload option
|
||||||
options.Preload = append(options.Preload, preloadOpt)
|
options.Preload = append(options.Preload, preloadOpt)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user