From 311e50bfdd0b6a21c76f21534272a7393cf4d3c8 Mon Sep 17 00:00:00 2001 From: Hein Date: Thu, 20 Nov 2025 14:30:59 +0200 Subject: [PATCH] Better relation lookup --- pkg/restheadspec/context.go | 19 +++- pkg/restheadspec/handler.go | 10 +-- pkg/restheadspec/headers.go | 174 +++++++++++++++++++++++++++++++++++- 3 files changed, 195 insertions(+), 8 deletions(-) diff --git a/pkg/restheadspec/context.go b/pkg/restheadspec/context.go index 820ef21..c57a8f7 100644 --- a/pkg/restheadspec/context.go +++ b/pkg/restheadspec/context.go @@ -13,6 +13,7 @@ const ( contextKeyTableName contextKey = "tableName" contextKeyModel contextKey = "model" contextKeyModelPtr contextKey = "modelPtr" + contextKeyOptions contextKey = "options" ) // WithSchema adds schema to context @@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} { return ctx.Value(contextKeyModelPtr) } +// WithOptions adds request options to context +func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context { + return context.WithValue(ctx, contextKeyOptions, options) +} + +// GetOptions retrieves request options from context +func GetOptions(ctx context.Context) *ExtendedRequestOptions { + if v := ctx.Value(contextKeyOptions); v != nil { + if opts, ok := v.(ExtendedRequestOptions); ok { + return &opts + } + } + return nil +} + // WithRequestData adds all request-scoped data to context at once -func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context { +func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context { ctx = WithSchema(ctx, schema) ctx = WithEntity(ctx, entity) ctx = WithTableName(ctx, tableName) ctx = WithModel(ctx, model) ctx = WithModelPtr(ctx, modelPtr) + ctx = WithOptions(ctx, options) return ctx } diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 0dc316a..28912e6 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -65,9 +65,6 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s entity := params["entity"] id := params["id"] - // Parse options from headers (now returns ExtendedRequestOptions) - options := h.parseOptionsFromHeaders(r) - // Determine operation based on HTTP method method := r.Method() @@ -104,13 +101,16 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s modelPtr := reflect.New(reflect.TypeOf(model)).Interface() tableName := h.getTableName(schema, entity, model) - // Add request-scoped data to context - ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr) + // Parse options from headers - this now includes relation name resolution + options := h.parseOptionsFromHeaders(r, model) // Validate and filter columns in options (log warnings for invalid columns) validator := common.NewColumnValidator(model) options = filterExtendedOptions(validator, options) + // Add request-scoped data to context (including options) + ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options) + switch method { case "GET": if id != "" { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 09c2fa7..c395f98 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -99,7 +99,8 @@ func DecodeParam(pStr string) (string, error) { } // parseOptionsFromHeaders parses all request options from HTTP headers -func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptions { +// If model is provided, it will resolve table names to field names in preload/expand options +func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) ExtendedRequestOptions { options := ExtendedRequestOptions{ RequestOptions: common.RequestOptions{ Filters: make([]common.FilterOption, 0), @@ -225,6 +226,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request) ExtendedRequestOptio } } + // Resolve relation names (convert table names to field names) if model is provided + if model != nil { + h.resolveRelationNamesInOptions(&options, model) + } + return options } @@ -655,6 +661,169 @@ func (h *Handler) processXFilesRelations(xfiles *XFiles, options *ExtendedReques } } +// resolveRelationNamesInOptions resolves all table names to field names in preload options +// This is called internally by parseOptionsFromHeaders when a model is provided +func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions, model interface{}) { + if options == nil || model == nil { + return + } + + // Resolve relation names in all preload options + for i := range options.Preload { + preload := &options.Preload[i] + + // Split the relation path (e.g., "parent.child.grandchild") + parts := strings.Split(preload.Relation, ".") + resolvedParts := make([]string, 0, len(parts)) + + // Resolve each part of the path + currentModel := model + for _, part := range parts { + resolvedPart := h.resolveRelationName(currentModel, part) + resolvedParts = append(resolvedParts, resolvedPart) + + // Try to get the model type for the next level + // This allows nested resolution + if nextModel := h.getRelationModel(currentModel, resolvedPart); nextModel != nil { + currentModel = nextModel + } + } + + // Update the relation path with resolved names + resolvedPath := strings.Join(resolvedParts, ".") + if resolvedPath != preload.Relation { + logger.Debug("Resolved relation path '%s' -> '%s'", preload.Relation, resolvedPath) + preload.Relation = resolvedPath + } + } + + // Resolve relation names in expand options + for i := range options.Expand { + expand := &options.Expand[i] + resolved := h.resolveRelationName(model, expand.Relation) + if resolved != expand.Relation { + logger.Debug("Resolved expand relation '%s' -> '%s'", expand.Relation, resolved) + expand.Relation = resolved + } + } +} + +// getRelationModel gets the model type for a relation field +func (h *Handler) getRelationModel(model interface{}, fieldName string) interface{} { + if model == nil || fieldName == "" { + return nil + } + + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + if modelType.Kind() != reflect.Struct { + return nil + } + + // Find the field + field, found := modelType.FieldByName(fieldName) + if !found { + return nil + } + + // Get the target type + targetType := field.Type + if targetType.Kind() == reflect.Slice { + targetType = targetType.Elem() + } + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + } + + if targetType.Kind() != reflect.Struct { + return nil + } + + // Create a zero value of the target type + return reflect.New(targetType).Elem().Interface() +} + +// resolveRelationName resolves a relation name or table name to the actual field name in the model +// If the input is already a field name, it returns it as-is +// If the input is a table name, it looks up the corresponding relation field +func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) string { + if model == nil || nameOrTable == "" { + return nameOrTable + } + + modelType := reflect.TypeOf(model) + // Dereference pointer if needed + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + // Ensure it's a struct + if modelType.Kind() != reflect.Struct { + return nameOrTable + } + + // First, check if the input matches a field name directly + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + if field.Name == nameOrTable { + // It's already a field name + logger.Debug("Input '%s' is a field name", nameOrTable) + return nameOrTable + } + } + + // If not found as a field name, try to look it up as a table name + normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", "")) + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + fieldType := field.Type + + // Check if it's a slice or pointer to a struct + var targetType reflect.Type + if fieldType.Kind() == reflect.Slice { + targetType = fieldType.Elem() + } else if fieldType.Kind() == reflect.Ptr { + targetType = fieldType.Elem() + } + + if targetType != nil { + // Dereference pointer if the slice contains pointers + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + } + + // Check if it's a struct type + if targetType.Kind() == reflect.Struct { + // Get the type name and normalize it + typeName := targetType.Name() + + // Extract the table name from type name + // Patterns: ModelCoreMastertaskitem -> mastertaskitem + // ModelMastertaskitem -> mastertaskitem + normalizedTypeName := strings.ToLower(typeName) + + // Remove common prefixes like "model", "modelcore", etc. + normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore") + normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model") + + // Compare normalized names + if normalizedTypeName == normalizedInput { + logger.Debug("Resolved table name '%s' to field '%s' (type: %s)", nameOrTable, field.Name, typeName) + return field.Name + } + } + } + } + + // If no match found, return the original input + logger.Debug("No field found for '%s', using as-is", nameOrTable) + return nameOrTable +} + // addXFilesPreload converts an XFiles relation into a PreloadOption // and recursively processes its children func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) { @@ -662,7 +831,8 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption return } - // Determine the relation path + // Store the table name as-is for now - it will be resolved to field name later + // when we have the model instance available relationPath := xfile.TableName if basePath != "" { relationPath = basePath + "." + xfile.TableName