mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-30 00:04:25 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a3564f05f | ||
|
|
a931b8cdd2 |
@@ -43,6 +43,11 @@ type PreloadOption struct {
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||
|
||||
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
|
||||
@@ -17,3 +17,33 @@ func Len(v any) int {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
|
||||
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
|
||||
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
|
||||
// dots, it returns everything after the last dot up to the first delimiter.
|
||||
func ExtractTableNameOnly(fullName string) string {
|
||||
// First, split by dot to remove schema prefix if present
|
||||
lastDotIndex := -1
|
||||
for i, char := range fullName {
|
||||
if char == '.' {
|
||||
lastDotIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// Start from after the last dot (or from beginning if no dot)
|
||||
startIndex := 0
|
||||
if lastDotIndex != -1 {
|
||||
startIndex = lastDotIndex + 1
|
||||
}
|
||||
|
||||
// Now find the end (first delimiter after the table name)
|
||||
for i := startIndex; i < len(fullName); i++ {
|
||||
char := rune(fullName[i])
|
||||
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
|
||||
return fullName[startIndex:i]
|
||||
}
|
||||
}
|
||||
|
||||
return fullName[startIndex:]
|
||||
}
|
||||
|
||||
@@ -756,11 +756,42 @@ func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error)
|
||||
// 2. Bun tag name (if exists)
|
||||
// 3. Gorm tag name (if exists)
|
||||
// 4. JSON tag name (if exists)
|
||||
//
|
||||
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
|
||||
// For nested fields, it traverses through each level of the struct hierarchy
|
||||
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split the field name by "." to handle nested/recursive relations
|
||||
fieldParts := strings.Split(fieldName, ".")
|
||||
|
||||
// Start with the current model
|
||||
currentModel := model
|
||||
|
||||
// Traverse through each level of the field path
|
||||
for _, part := range fieldParts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
currentModel = getRelationModelSingleLevel(currentModel, part)
|
||||
if currentModel == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return currentModel
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nil
|
||||
|
||||
@@ -1209,7 +1209,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
}
|
||||
|
||||
if len(preload.Where) > 0 {
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
|
||||
@@ -392,7 +392,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
if options.CustomSQLWhere != "" {
|
||||
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, "")
|
||||
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -402,7 +402,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
if options.CustomSQLOr != "" {
|
||||
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
|
||||
// Sanitize without auto-prefixing since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, "")
|
||||
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
@@ -481,7 +481,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Apply cursor filter to query
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, "")
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName))
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
@@ -560,11 +560,33 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
|
||||
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, model interface{}, depth int) common.SelectQuery {
|
||||
// Log relationship keys if they're specified (from XFiles)
|
||||
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
|
||||
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
|
||||
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
|
||||
|
||||
// Build a WHERE clause using the relationship keys if needed
|
||||
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
|
||||
// However, if the relationship keys are explicitly provided from XFiles, we can use them
|
||||
// to add additional filtering or validation
|
||||
if preload.RelatedKey != "" && preload.Where == "" {
|
||||
// For child tables: ensure the child's relatedkey column will be matched
|
||||
// The actual parent value is dynamic and handled by Bun's preload mechanism
|
||||
// We just log this for visibility
|
||||
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
|
||||
preload.Relation, preload.RelatedKey)
|
||||
}
|
||||
if preload.ForeignKey != "" && preload.Where == "" {
|
||||
// For parent tables: ensure the parent's primary key matches the current table's foreign key
|
||||
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
|
||||
preload.Relation, preload.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the preload
|
||||
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
|
||||
// Get the related model for column operations
|
||||
relationParts := strings.Split(preload.Relation, ",")
|
||||
relatedModel := reflection.GetRelationModel(model, relationParts[0])
|
||||
relatedModel := reflection.GetRelationModel(model, preload.Relation)
|
||||
if relatedModel == nil {
|
||||
logger.Warn("Could not get related model for preload: %s", preload.Relation)
|
||||
// relatedModel = model // fallback to parent model
|
||||
@@ -633,7 +655,7 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
|
||||
// Apply WHERE clause
|
||||
if len(preload.Where) > 0 {
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, preload.Relation)
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
|
||||
@@ -919,6 +919,20 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
// Set recursive flag
|
||||
preloadOpt.Recursive = xfile.Recursive
|
||||
|
||||
// Extract relationship keys for proper foreign key filtering
|
||||
if xfile.PrimaryKey != "" {
|
||||
preloadOpt.PrimaryKey = xfile.PrimaryKey
|
||||
logger.Debug("X-Files: Set primary key for %s: %s", relationPath, xfile.PrimaryKey)
|
||||
}
|
||||
if xfile.RelatedKey != "" {
|
||||
preloadOpt.RelatedKey = xfile.RelatedKey
|
||||
logger.Debug("X-Files: Set related key for %s: %s", relationPath, xfile.RelatedKey)
|
||||
}
|
||||
if xfile.ForeignKey != "" {
|
||||
preloadOpt.ForeignKey = xfile.ForeignKey
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Add the preload option
|
||||
options.Preload = append(options.Preload, preloadOpt)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user