diff --git a/pkg/common/validation.go b/pkg/common/validation.go new file mode 100644 index 0000000..a70fed2 --- /dev/null +++ b/pkg/common/validation.go @@ -0,0 +1,272 @@ +package common + +import ( + "fmt" + "reflect" + "strings" + + "github.com/Warky-Devs/ResolveSpec/pkg/logger" +) + +// ColumnValidator validates column names against a model's fields +type ColumnValidator struct { + validColumns map[string]bool + model interface{} +} + +// NewColumnValidator creates a new column validator for a given model +func NewColumnValidator(model interface{}) *ColumnValidator { + validator := &ColumnValidator{ + validColumns: make(map[string]bool), + model: model, + } + validator.buildValidColumns() + return validator +} + +// buildValidColumns extracts all valid column names from the model using reflection +func (v *ColumnValidator) buildValidColumns() { + modelType := reflect.TypeOf(v.model) + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + return + } + + // Extract column names from struct fields + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + if !field.IsExported() { + continue + } + + // Get column name from bun, gorm, or json tag + columnName := v.getColumnName(field) + if columnName != "" && columnName != "-" { + v.validColumns[strings.ToLower(columnName)] = true + } + } +} + +// getColumnName extracts the column name from a struct field's tags +// Supports both Bun and GORM tags +func (v *ColumnValidator) getColumnName(field reflect.StructField) string { + // First check Bun tag for column name + bunTag := field.Tag.Get("bun") + if bunTag != "" && bunTag != "-" { + parts := strings.Split(bunTag, ",") + // The first part is usually the column name + columnName := strings.TrimSpace(parts[0]) + if columnName != "" && columnName != "-" { + return columnName + } + } + + // Check GORM tag for column name + gormTag := field.Tag.Get("gorm") + if strings.Contains(gormTag, "column:") { + parts := strings.Split(gormTag, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "column:") { + return strings.TrimPrefix(part, "column:") + } + } + } + + // Fall back to JSON tag + jsonTag := field.Tag.Get("json") + if jsonTag != "" && jsonTag != "-" { + // Extract just the name part (before any comma) + jsonName := strings.Split(jsonTag, ",")[0] + return jsonName + } + + // Fall back to field name in lowercase (snake_case conversion would be better) + return strings.ToLower(field.Name) +} + +// ValidateColumn validates a single column name +// Returns nil if valid, error if invalid +// Columns prefixed with "cql" (case insensitive) are always valid +func (v *ColumnValidator) ValidateColumn(column string) error { + // Allow empty columns + if column == "" { + return nil + } + + // Allow columns prefixed with "cql" (case insensitive) for computed columns + if strings.HasPrefix(strings.ToLower(column), "cql") { + return nil + } + + // Check if column exists in model + if _, exists := v.validColumns[strings.ToLower(column)]; !exists { + return fmt.Errorf("invalid column '%s': column does not exist in model", column) + } + + return nil +} + +// IsValidColumn checks if a column is valid +// Returns true if valid, false if invalid +func (v *ColumnValidator) IsValidColumn(column string) bool { + return v.ValidateColumn(column) == nil +} + +// FilterValidColumns filters a list of columns, returning only valid ones +// Logs warnings for any invalid columns +func (v *ColumnValidator) FilterValidColumns(columns []string) []string { + if len(columns) == 0 { + return columns + } + + validColumns := make([]string, 0, len(columns)) + for _, col := range columns { + if v.IsValidColumn(col) { + validColumns = append(validColumns, col) + } else { + logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col) + } + } + return validColumns +} + +// ValidateColumns validates multiple column names +// Returns error with details about all invalid columns +func (v *ColumnValidator) ValidateColumns(columns []string) error { + var invalidColumns []string + + for _, column := range columns { + if err := v.ValidateColumn(column); err != nil { + invalidColumns = append(invalidColumns, column) + } + } + + if len(invalidColumns) > 0 { + return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", ")) + } + + return nil +} + +// ValidateRequestOptions validates all column references in RequestOptions +func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error { + // Validate Columns + if err := v.ValidateColumns(options.Columns); err != nil { + return fmt.Errorf("in select columns: %w", err) + } + + // Validate OmitColumns + if err := v.ValidateColumns(options.OmitColumns); err != nil { + return fmt.Errorf("in omit columns: %w", err) + } + + // Validate Filter columns + for _, filter := range options.Filters { + if err := v.ValidateColumn(filter.Column); err != nil { + return fmt.Errorf("in filter: %w", err) + } + } + + // Validate Sort columns + for _, sort := range options.Sort { + if err := v.ValidateColumn(sort.Column); err != nil { + return fmt.Errorf("in sort: %w", err) + } + } + + // Validate Preload columns (if specified) + for _, preload := range options.Preload { + // Note: We don't validate the relation name itself, as it's a relationship + // Only validate columns if specified for the preload + if err := v.ValidateColumns(preload.Columns); err != nil { + return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err) + } + if err := v.ValidateColumns(preload.OmitColumns); err != nil { + return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err) + } + + // Validate filter columns in preload + for _, filter := range preload.Filters { + if err := v.ValidateColumn(filter.Column); err != nil { + return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err) + } + } + } + + return nil +} + +// FilterRequestOptions filters all column references in RequestOptions +// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones +func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions { + filtered := options + + // Filter Columns + filtered.Columns = v.FilterValidColumns(options.Columns) + + // Filter OmitColumns + filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns) + + // Filter Filter columns + validFilters := make([]FilterOption, 0, len(options.Filters)) + for _, filter := range options.Filters { + if v.IsValidColumn(filter.Column) { + validFilters = append(validFilters, filter) + } else { + logger.Warn("Invalid column in filter '%s' removed", filter.Column) + } + } + filtered.Filters = validFilters + + // Filter Sort columns + validSorts := make([]SortOption, 0, len(options.Sort)) + for _, sort := range options.Sort { + if v.IsValidColumn(sort.Column) { + validSorts = append(validSorts, sort) + } else { + logger.Warn("Invalid column in sort '%s' removed", sort.Column) + } + } + filtered.Sort = validSorts + + // Filter Preload columns + validPreloads := make([]PreloadOption, 0, len(options.Preload)) + for _, preload := range options.Preload { + filteredPreload := preload + filteredPreload.Columns = v.FilterValidColumns(preload.Columns) + filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns) + + // Filter preload filters + validPreloadFilters := make([]FilterOption, 0, len(preload.Filters)) + for _, filter := range preload.Filters { + if v.IsValidColumn(filter.Column) { + validPreloadFilters = append(validPreloadFilters, filter) + } else { + logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column) + } + } + filteredPreload.Filters = validPreloadFilters + + validPreloads = append(validPreloads, filteredPreload) + } + filtered.Preload = validPreloads + + return filtered +} + +// GetValidColumns returns a list of all valid column names for debugging purposes +func (v *ColumnValidator) GetValidColumns() []string { + columns := make([]string, 0, len(v.validColumns)) + for col := range v.validColumns { + columns = append(columns, col) + } + return columns +} diff --git a/pkg/common/validation_test.go b/pkg/common/validation_test.go new file mode 100644 index 0000000..a68be98 --- /dev/null +++ b/pkg/common/validation_test.go @@ -0,0 +1,363 @@ +package common + +import ( + "strings" + "testing" +) + +// TestModel represents a sample model for testing +type TestModel struct { + ID int64 `json:"id" gorm:"primaryKey"` + Name string `json:"name" gorm:"column:name"` + Email string `json:"email" bun:"email"` + Age int `json:"age"` + IsActive bool `json:"is_active"` + CreatedAt string `json:"created_at"` +} + +func TestNewColumnValidator(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + if validator == nil { + t.Fatal("Expected validator to be created") + } + + if len(validator.validColumns) == 0 { + t.Fatal("Expected validator to have valid columns") + } + + // Check that expected columns are present + expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"} + for _, col := range expectedColumns { + if !validator.validColumns[col] { + t.Errorf("Expected column '%s' to be valid", col) + } + } +} + +func TestValidateColumn(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + tests := []struct { + name string + column string + shouldError bool + }{ + {"Valid column - id", "id", false}, + {"Valid column - name", "name", false}, + {"Valid column - email", "email", false}, + {"Valid column - uppercase", "ID", false}, // Case insensitive + {"Invalid column", "invalid_column", true}, + {"CQL prefixed - should be valid", "cqlComputedField", false}, + {"CQL prefixed uppercase - should be valid", "CQLComputedField", false}, + {"Empty column", "", false}, // Empty columns are allowed + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateColumn(tt.column) + if tt.shouldError && err == nil { + t.Errorf("Expected error for column '%s', got nil", tt.column) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for column '%s', got: %v", tt.column, err) + } + }) + } +} + +func TestValidateColumns(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + tests := []struct { + name string + columns []string + shouldError bool + }{ + {"All valid columns", []string{"id", "name", "email"}, false}, + {"One invalid column", []string{"id", "invalid_col", "name"}, true}, + {"All invalid columns", []string{"bad1", "bad2"}, true}, + {"With CQL prefix", []string{"id", "cqlComputed", "name"}, false}, + {"Empty list", []string{}, false}, + {"Nil list", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateColumns(tt.columns) + if tt.shouldError && err == nil { + t.Errorf("Expected error for columns %v, got nil", tt.columns) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err) + } + }) + } +} + +func TestValidateRequestOptions(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + tests := []struct { + name string + options RequestOptions + shouldError bool + errorMsg string + }{ + { + name: "Valid options with columns", + options: RequestOptions{ + Columns: []string{"id", "name"}, + Filters: []FilterOption{ + {Column: "name", Operator: "eq", Value: "test"}, + }, + Sort: []SortOption{ + {Column: "id", Direction: "ASC"}, + }, + }, + shouldError: false, + }, + { + name: "Invalid column in Columns", + options: RequestOptions{ + Columns: []string{"id", "invalid_column"}, + }, + shouldError: true, + errorMsg: "select columns", + }, + { + name: "Invalid column in Filters", + options: RequestOptions{ + Filters: []FilterOption{ + {Column: "invalid_col", Operator: "eq", Value: "test"}, + }, + }, + shouldError: true, + errorMsg: "filter", + }, + { + name: "Invalid column in Sort", + options: RequestOptions{ + Sort: []SortOption{ + {Column: "invalid_col", Direction: "ASC"}, + }, + }, + shouldError: true, + errorMsg: "sort", + }, + { + name: "Valid CQL prefixed columns", + options: RequestOptions{ + Columns: []string{"id", "cqlComputedField"}, + Filters: []FilterOption{ + {Column: "cqlCustomFilter", Operator: "eq", Value: "test"}, + }, + }, + shouldError: false, + }, + { + name: "Invalid column in Preload", + options: RequestOptions{ + Preload: []PreloadOption{ + { + Relation: "SomeRelation", + Columns: []string{"id", "invalid_col"}, + }, + }, + }, + shouldError: true, + errorMsg: "preload", + }, + { + name: "Valid preload with valid columns", + options: RequestOptions{ + Preload: []PreloadOption{ + { + Relation: "SomeRelation", + Columns: []string{"id", "name"}, + }, + }, + }, + shouldError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validator.ValidateRequestOptions(tt.options) + if tt.shouldError { + if err == nil { + t.Errorf("Expected error, got nil") + } else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) { + t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err) + } + } else { + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + } + }) + } +} + +func TestGetValidColumns(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + columns := validator.GetValidColumns() + if len(columns) == 0 { + t.Error("Expected to get valid columns, got empty list") + } + + // Should have at least the columns from TestModel + if len(columns) < 6 { + t.Errorf("Expected at least 6 columns, got %d", len(columns)) + } +} + +// Test with Bun tags specifically +type BunModel struct { + ID int64 `bun:"id,pk"` + Name string `bun:"name"` + Email string `bun:"user_email"` +} + +func TestBunTagSupport(t *testing.T) { + model := BunModel{} + validator := NewColumnValidator(model) + + // Test that bun tags are properly recognized + tests := []struct { + column string + shouldError bool + }{ + {"id", false}, + {"name", false}, + {"user_email", false}, // Bun tag specifies this name + {"email", true}, // JSON tag would be "email", but bun tag says "user_email" + } + + for _, tt := range tests { + t.Run(tt.column, func(t *testing.T) { + err := validator.ValidateColumn(tt.column) + if tt.shouldError && err == nil { + t.Errorf("Expected error for column '%s'", tt.column) + } + if !tt.shouldError && err != nil { + t.Errorf("Expected no error for column '%s', got: %v", tt.column, err) + } + }) + } +} + +func TestFilterValidColumns(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + tests := []struct { + name string + input []string + expectedOutput []string + }{ + { + name: "All valid columns", + input: []string{"id", "name", "email"}, + expectedOutput: []string{"id", "name", "email"}, + }, + { + name: "Mix of valid and invalid", + input: []string{"id", "invalid_col", "name", "bad_col", "email"}, + expectedOutput: []string{"id", "name", "email"}, + }, + { + name: "All invalid columns", + input: []string{"bad1", "bad2"}, + expectedOutput: []string{}, + }, + { + name: "With CQL prefix (should pass)", + input: []string{"id", "cqlComputed", "name"}, + expectedOutput: []string{"id", "cqlComputed", "name"}, + }, + { + name: "Empty input", + input: []string{}, + expectedOutput: []string{}, + }, + { + name: "Nil input", + input: nil, + expectedOutput: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := validator.FilterValidColumns(tt.input) + if len(result) != len(tt.expectedOutput) { + t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result)) + } + for i, col := range result { + if col != tt.expectedOutput[i] { + t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col) + } + } + }) + } +} + +func TestFilterRequestOptions(t *testing.T) { + model := TestModel{} + validator := NewColumnValidator(model) + + options := RequestOptions{ + Columns: []string{"id", "name", "invalid_col"}, + OmitColumns: []string{"email", "bad_col"}, + Filters: []FilterOption{ + {Column: "name", Operator: "eq", Value: "test"}, + {Column: "invalid_col", Operator: "eq", Value: "test"}, + }, + Sort: []SortOption{ + {Column: "id", Direction: "ASC"}, + {Column: "bad_col", Direction: "DESC"}, + }, + } + + filtered := validator.FilterRequestOptions(options) + + // Check Columns + if len(filtered.Columns) != 2 { + t.Errorf("Expected 2 columns, got %d", len(filtered.Columns)) + } + if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" { + t.Errorf("Expected columns [id, name], got %v", filtered.Columns) + } + + // Check OmitColumns + if len(filtered.OmitColumns) != 1 { + t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns)) + } + if filtered.OmitColumns[0] != "email" { + t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns) + } + + // Check Filters + if len(filtered.Filters) != 1 { + t.Errorf("Expected 1 filter, got %d", len(filtered.Filters)) + } + if filtered.Filters[0].Column != "name" { + t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column) + } + + // Check Sort + if len(filtered.Sort) != 1 { + t.Errorf("Expected 1 sort, got %d", len(filtered.Sort)) + } + if filtered.Sort[0].Column != "id" { + t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column) + } +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 92a1098..17ada65 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -100,6 +100,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s // Add request-scoped data to context ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr) + // Validate and filter columns in options (log warnings for invalid columns) + validator := common.NewColumnValidator(model) + req.Options = validator.FilterRequestOptions(req.Options) + switch req.Operation { case "read": h.handleRead(ctx, w, id, req.Options) diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 5bff8f8..06130ab 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -93,6 +93,10 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s // Add request-scoped data to context ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr) + // Validate and filter columns in options (log warnings for invalid columns) + validator := common.NewColumnValidator(model) + options = filterExtendedOptions(validator, options) + switch method { case "GET": if id != "" { @@ -750,3 +754,41 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa w.WriteHeader(statusCode) w.WriteJSON(response) } + +// filterExtendedOptions filters all column references, removing invalid ones and logging warnings +func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions { + filtered := options + + // Filter base RequestOptions + filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions) + + // Filter SearchColumns + filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns) + + // Filter AdvancedSQL column keys + filteredAdvSQL := make(map[string]string) + for colName, sqlExpr := range options.AdvancedSQL { + if validator.IsValidColumn(colName) { + filteredAdvSQL[colName] = sqlExpr + } else { + logger.Warn("Invalid column in advanced SQL removed: %s", colName) + } + } + filtered.AdvancedSQL = filteredAdvSQL + + // ComputedQL columns are allowed to be any name since they're computed + // No filtering needed for ComputedQL keys + filtered.ComputedQL = options.ComputedQL + + // Filter Expand columns + filteredExpands := make([]ExpandOption, 0, len(options.Expand)) + for _, expand := range options.Expand { + filteredExpand := expand + // Don't validate relation name, only columns + filteredExpand.Columns = validator.FilterValidColumns(expand.Columns) + filteredExpands = append(filteredExpands, filteredExpand) + } + filtered.Expand = filteredExpands + + return filtered +} diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 737fa88..068058b 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -57,27 +57,43 @@ type ExpandOption struct { // decodeHeaderValue decodes base64 encoded header values // Supports ZIP_ and __ prefixes for base64 encoding func decodeHeaderValue(value string) string { - // Check for ZIP_ prefix - if strings.HasPrefix(value, "ZIP_") { - decoded, err := base64.StdEncoding.DecodeString(value[4:]) - if err == nil { - return string(decoded) + str, _ := DecodeParam(value) + return str +} + +// DecodeParam - Decodes parameter string and returns unencoded string +func DecodeParam(pStr string) (string, error) { + var code string = pStr + if strings.HasPrefix(pStr, "ZIP_") { + code = strings.ReplaceAll(pStr, "ZIP_", "") + code = strings.ReplaceAll(code, "\n", "") + code = strings.ReplaceAll(code, "\r", "") + code = strings.ReplaceAll(code, " ", "") + strDat, err := base64.StdEncoding.DecodeString(code) + if err != nil { + return code, fmt.Errorf("failed to read parameter base64: %v", err) + } else { + code = string(strDat) + } + } else if strings.HasPrefix(pStr, "__") { + code = strings.ReplaceAll(pStr, "__", "") + code = strings.ReplaceAll(code, "\n", "") + code = strings.ReplaceAll(code, "\r", "") + code = strings.ReplaceAll(code, " ", "") + + strDat, err := base64.StdEncoding.DecodeString(code) + if err != nil { + return code, fmt.Errorf("failed to read parameter base64: %v", err) + } else { + code = string(strDat) } - logger.Warn("Failed to decode ZIP_ prefixed value: %v", err) - return value } - // Check for __ prefix - if strings.HasPrefix(value, "__") { - decoded, err := base64.StdEncoding.DecodeString(value[2:]) - if err == nil { - return string(decoded) - } - logger.Warn("Failed to decode __ prefixed value: %v", err) - return value + if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") { + code, _ = DecodeParam(code) } - return value + return code, nil } // parseOptionsFromHeaders parses all request options from HTTP headers