diff --git a/pkg/common/validation.go b/pkg/common/validation.go index e70d589..ba97bb5 100644 --- a/pkg/common/validation.go +++ b/pkg/common/validation.go @@ -92,9 +92,27 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string { return strings.ToLower(field.Name) } +// extractSourceColumn extracts the base column name from PostgreSQL JSON operators +// Examples: +// - "columna->>'val'" returns "columna" +// - "columna->'key'" returns "columna" +// - "columna" returns "columna" +// - "table.columna->>'val'" returns "table.columna" +func extractSourceColumn(colName string) string { + // Check for PostgreSQL JSON operators: -> and ->> + if idx := strings.Index(colName, "->>"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + if idx := strings.Index(colName, "->"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + return colName +} + // ValidateColumn validates a single column name // Returns nil if valid, error if invalid // Columns prefixed with "cql" (case insensitive) are always valid +// Handles PostgreSQL JSON operators (-> and ->>) func (v *ColumnValidator) ValidateColumn(column string) error { // Allow empty columns if column == "" { @@ -106,8 +124,11 @@ func (v *ColumnValidator) ValidateColumn(column string) error { return nil } + // Extract source column name (remove JSON operators like ->> or ->) + sourceColumn := extractSourceColumn(column) + // Check if column exists in model - if _, exists := v.validColumns[strings.ToLower(column)]; !exists { + if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists { return fmt.Errorf("invalid column '%s': column does not exist in model", column) } diff --git a/pkg/common/validation_json_test.go b/pkg/common/validation_json_test.go new file mode 100644 index 0000000..1c6273a --- /dev/null +++ b/pkg/common/validation_json_test.go @@ -0,0 +1,124 @@ +package common + +import ( + "testing" +) + +func TestExtractSourceColumn(t *testing.T) { + testCases := []struct { + name string + input string + expected string + }{ + { + name: "simple column name", + input: "columna", + expected: "columna", + }, + { + name: "column with ->> operator", + input: "columna->>'val'", + expected: "columna", + }, + { + name: "column with -> operator", + input: "columna->'key'", + expected: "columna", + }, + { + name: "column with table prefix and ->> operator", + input: "table.columna->>'val'", + expected: "table.columna", + }, + { + name: "column with table prefix and -> operator", + input: "table.columna->'key'", + expected: "table.columna", + }, + { + name: "complex JSON path with ->>", + input: "data->>'nested'->>'value'", + expected: "data", + }, + { + name: "column with spaces before operator", + input: "columna ->>'val'", + expected: "columna", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := extractSourceColumn(tc.input) + if result != tc.expected { + t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected) + } + }) + } +} + +func TestValidateColumnWithJSONOperators(t *testing.T) { + // Create a test model + type TestModel struct { + ID int `json:"id"` + Name string `json:"name"` + Data string `json:"data"` // JSON column + Metadata string `json:"metadata"` + } + + validator := NewColumnValidator(TestModel{}) + + testCases := []struct { + name string + column string + shouldErr bool + }{ + { + name: "simple valid column", + column: "name", + shouldErr: false, + }, + { + name: "valid column with ->> operator", + column: "data->>'field'", + shouldErr: false, + }, + { + name: "valid column with -> operator", + column: "metadata->'key'", + shouldErr: false, + }, + { + name: "invalid column", + column: "invalid_column", + shouldErr: true, + }, + { + name: "invalid column with ->> operator", + column: "invalid_column->>'field'", + shouldErr: true, + }, + { + name: "cql prefixed column (always valid)", + column: "cql_computed", + shouldErr: false, + }, + { + name: "empty column", + column: "", + shouldErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validator.ValidateColumn(tc.column) + if tc.shouldErr && err == nil { + t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column) + } + if !tc.shouldErr && err != nil { + t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err) + } + }) + } +} diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 8f06f31..9f37acc 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -1647,10 +1647,13 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac data = h.normalizeResultArray(data) } - response := common.Response{ - Success: true, - Data: data, - Metadata: metadata, + response := data + if response == nil { + response = common.Response{ + Success: true, + Data: data, + Metadata: metadata, + } } w.WriteHeader(http.StatusOK) if err := w.WriteJSON(response); err != nil { diff --git a/pkg/restheadspec/headers.go b/pkg/restheadspec/headers.go index 053f834..f105703 100644 --- a/pkg/restheadspec/headers.go +++ b/pkg/restheadspec/headers.go @@ -480,12 +480,32 @@ func (h *Handler) parseCommaSeparated(value string) []string { return result } +// extractSourceColumn extracts the base column name from PostgreSQL JSON operators +// Examples: +// - "columna->>'val'" returns "columna" +// - "columna->'key'" returns "columna" +// - "columna" returns "columna" +// - "table.columna->>'val'" returns "table.columna" +func extractSourceColumn(colName string) string { + // Check for PostgreSQL JSON operators: -> and ->> + if idx := strings.Index(colName, "->>"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + if idx := strings.Index(colName, "->"); idx != -1 { + return strings.TrimSpace(colName[:idx]) + } + return colName +} + // getColumnTypeFromModel uses reflection to determine the Go type of a column in a model func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) reflect.Kind { if model == nil { return reflect.Invalid } + // Extract the source column name (remove JSON operators like ->> or ->) + sourceColName := extractSourceColumn(colName) + modelType := reflect.TypeOf(model) // Dereference pointer if needed if modelType.Kind() == reflect.Ptr { @@ -506,19 +526,19 @@ func (h *Handler) getColumnTypeFromModel(model interface{}, colName string) refl if jsonTag != "" { // Parse JSON tag (format: "name,omitempty") parts := strings.Split(jsonTag, ",") - if parts[0] == colName { + if parts[0] == sourceColName { return field.Type.Kind() } } // Check field name (case-insensitive) - if strings.EqualFold(field.Name, colName) { + if strings.EqualFold(field.Name, sourceColName) { return field.Type.Kind() } // Check snake_case conversion snakeCaseName := toSnakeCase(field.Name) - if snakeCaseName == colName { + if snakeCaseName == sourceColName { return field.Type.Kind() } }