diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index d0a5cbe..7e81953 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s // Validate and filter columns in options (log warnings for invalid columns) validator := common.NewColumnValidator(model) - options = filterExtendedOptions(validator, options) + options = h.filterExtendedOptions(validator, options, model) // Add request-scoped data to context (including options) ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options) @@ -2241,7 +2241,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) { } // filterExtendedOptions filters all column references, removing invalid ones and logging warnings -func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions { +func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions { filtered := options // Filter base RequestOptions @@ -2265,15 +2265,33 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe // No filtering needed for ComputedQL keys filtered.ComputedQL = options.ComputedQL - // Filter Expand columns - // filteredExpands := make([]ExpandOption, 0, len(options.Expand)) - // for _, expand := range options.Expand { - // filteredExpand := expand - // // Don't validate relation name, only columns - // filteredExpand.Columns = validator.FilterValidColumns(expand.Columns) - // filteredExpands = append(filteredExpands, filteredExpand) - // } - // filtered.Expand = filteredExpands + // Filter Expand columns using the expand relation's model + filteredExpands := make([]ExpandOption, 0, len(options.Expand)) + modelType := reflect.TypeOf(model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + for _, expand := range options.Expand { + filteredExpand := expand + + // Get the relationship info for this expand relation + relInfo := h.getRelationshipInfo(modelType, expand.Relation) + if relInfo != nil && relInfo.relatedModel != nil { + // Create a validator for the related model + expandValidator := common.NewColumnValidator(relInfo.relatedModel) + // Filter columns using the related model's validator + filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns) + } else { + // If we can't find the relationship, log a warning and skip column filtering + logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation) + // Keep the columns as-is if we can't validate them + filteredExpand.Columns = expand.Columns + } + + filteredExpands = append(filteredExpands, filteredExpand) + } + filtered.Expand = filteredExpands return filtered }