Compare commits

...

17 Commits

Author SHA1 Message Date
Hein c120b49529 fix(router): prevent HTML escaping in JSON responses
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
fix(sql_helpers): avoid prefix extraction in subqueries
2026-06-08 15:13:58 +02:00
Hein 66348dac97 test(handler): add tests for valid nested request verbs 2026-06-08 09:06:29 +02:00
Hein a87cd18b1b fix(handler): validate nested request structure for relations
* added checks for valid _request values in single and multiple relations
* introduced isValidNestedRequest function to encapsulate validation logic
fix(crud): expand operation handling for nested CUD
* added "add" to insert operations and "modify" to update operations
* included "remove" in delete operations
2026-06-08 09:02:29 +02:00
Hein 29449c93d5 fix(test): add tests for asymmetric join column handling
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2026-06-07 19:13:59 +02:00
Hein 3b6e5c75be fix(handler): update foreign key field resolution logic
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
* Adjust foreign key field name selection for has-many/has-one relationships
* Improve logging to clarify foreign key and child field usage
2026-06-07 14:20:55 +02:00
Hein 549ccb8468 fix(handler): fetch updated records after transaction commits
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
* Update selection queries to use model columns
* Ensure updated records are fetched and returned in responses
2026-06-05 11:12:04 +02:00
Hein 1af9c76337 fix(handler): fetch updated record after transaction commits
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
2026-06-04 18:23:18 +02:00
Hein 938a2ef3d9 fix(staticweb): add fallback for extensionless file paths
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after 6s
Tests / Integration Tests (push) Failing after 13m59s
Tests / Unit Tests (push) Failing after 14m11s
Build , Vet Test, and Lint / Build (push) Failing after 14m21s
Build , Vet Test, and Lint / Lint Code (push) Failing after 14m31s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after 14m45s
2026-05-27 18:41:43 +02:00
Hein 69cc3e2839 fix(db): update Returning method to accept multiple columns 2026-05-27 14:11:20 +02:00
Hein 4018af0636 fix(validation): enhance filter logic for column validation
* adjust handling of "all" filter to consider filtered columns
fix(function_api): improve variable substitution in SQL queries
* add safeSubstituteVar for context-aware value sanitization
2026-05-27 12:17:31 +02:00
Hein c4e79d6950 fix(validation): use strings.EqualFold for case-insensitive comparison 2026-05-27 12:07:08 +02:00
Hein 982a0e62ac fix(validation): add Columns method to retrieve valid column names 2026-05-27 12:06:46 +02:00
Hein 5d459c95a7 fix(headers): reorder import statements for clarity 2026-05-27 11:28:39 +02:00
Hein e9f7726e43 fix(headers): sort combined parameters before processing 2026-05-27 11:28:22 +02:00
Hein 3d2251317a fix(headers): remove unused utf8 validation in DecodeParam
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after 1s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after 1s
Build , Vet Test, and Lint / Lint Code (push) Failing after 1s
Build , Vet Test, and Lint / Build (push) Failing after 1s
Tests / Unit Tests (push) Failing after 0s
Tests / Integration Tests (push) Failing after 1s
2026-05-26 10:31:34 +02:00
Hein 1ce0ab1ab4 fix(validation): improve preload column validation logic
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Failing after -35m6s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Failing after -35m6s
Build , Vet Test, and Lint / Lint Code (push) Failing after -35m6s
Build , Vet Test, and Lint / Build (push) Failing after -35m6s
Tests / Unit Tests (push) Failing after -35m7s
Tests / Integration Tests (push) Failing after -35m7s
2026-05-21 20:18:01 +02:00
Hein 1f9b230f7f fix(validation): improve preload column validation logic
* Use related model's validator for filtering preload columns
* Ensure valid columns are checked against the correct validator
2026-05-21 20:16:53 +02:00
14 changed files with 586 additions and 67 deletions
+2 -2
View File
@@ -1489,7 +1489,7 @@ func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery {
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery { func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
if len(columns) > 0 { if len(columns) > 0 {
b.query = b.query.Returning(columns[0]) b.query = b.query.Returning(strings.Join(columns, ", "))
} }
return b return b
} }
@@ -1606,7 +1606,7 @@ func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQ
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery { func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
if len(columns) > 0 { if len(columns) > 0 {
b.query = b.query.Returning(columns[0]) b.query = b.query.Returning(strings.Join(columns, ", "))
} }
return b return b
} }
+3 -1
View File
@@ -174,7 +174,9 @@ func (h *HTTPResponseWriter) Write(data []byte) (int, error) {
func (h *HTTPResponseWriter) WriteJSON(data interface{}) error { func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
h.SetHeader("Content-Type", "application/json") h.SetHeader("Content-Type", "application/json")
return json.NewEncoder(h.resp).Encode(data) enc := json.NewEncoder(h.resp)
enc.SetEscapeHTML(false)
return enc.Encode(data)
} }
// UnderlyingResponseWriter returns the underlying http.ResponseWriter // UnderlyingResponseWriter returns the underlying http.ResponseWriter
+3 -1
View File
@@ -178,7 +178,9 @@ func (s *StandardResponseWriter) Write(data []byte) (int, error) {
func (s *StandardResponseWriter) WriteJSON(data interface{}) error { func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json") s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data) enc := json.NewEncoder(s.w)
enc.SetEscapeHTML(false)
return enc.Encode(data)
} }
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter { func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
+13 -9
View File
@@ -113,7 +113,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
// Process based on operation // Process based on operation
switch strings.ToLower(operation) { switch strings.ToLower(operation) {
case "insert", "create": case "insert", "create", "add":
// Only perform insert if we have data to insert // Only perform insert if we have data to insert
if hasData { if hasData {
id, err := p.processInsert(ctx, regularData, tableName) id, err := p.processInsert(ctx, regularData, tableName)
@@ -141,7 +141,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName) logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
} }
case "update", "change": case "update", "change", "modify":
// Only perform update if we have data to update // Only perform update if we have data to update
if reflection.IsEmptyValue(data[pkName]) { if reflection.IsEmptyValue(data[pkName]) {
logger.Warn("Skipping update for %s - no primary key", tableName) logger.Warn("Skipping update for %s - no primary key", tableName)
@@ -174,7 +174,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.ID = data[pkName] result.ID = data[pkName]
} }
case "delete": case "delete", "remove":
if reflection.IsEmptyValue(data[pkName]) { if reflection.IsEmptyValue(data[pkName]) {
logger.Warn("Skipping delete for %s - no primary key", tableName) logger.Warn("Skipping delete for %s - no primary key", tableName)
return result, nil return result, nil
@@ -471,13 +471,17 @@ func (p *NestedCUDProcessor) processChildRelations(
// Priority: Use foreign key field name if specified // Priority: Use foreign key field name if specified
var foreignKeyFieldName string var foreignKeyFieldName string
if relInfo.ForeignKey != "" { if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model // For has-many/has-one: join:parentCol=childCol
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) // ForeignKey = parent side, References = child side (where we actually set the value)
if foreignKeyFieldName == "" { childField := relInfo.ForeignKey
// Fallback to lowercase field name if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey) childField = relInfo.References
} }
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey) foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
if foreignKeyFieldName == "" {
foreignKeyFieldName = strings.ToLower(childField)
}
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s -> child %s)", foreignKeyFieldName, relInfo.ForeignKey, childField)
} }
// Get the primary key name for the child model to avoid overwriting it in recursive relationships // Get the primary key name for the child model to avoid overwriting it in recursive relationships
+214
View File
@@ -713,6 +713,220 @@ func TestInjectForeignKeys(t *testing.T) {
} }
} }
// Models for asymmetric join column tests (mirrors the bun has-many join:parentCol=childCol pattern).
// ActionOption has-many ActionOptionLinks via join:rid_actionoption=rid_actionoption_child.
// The child column ("rid_actionoption_child") differs from the parent column ("rid_actionoption").
type ActionOption struct {
RidActionoption int64 `json:"rid_actionoption" bun:"rid_actionoption,pk"`
Label string `json:"label"`
Links []*ActionOptionLink `json:"aol_rid_actionoption_child,omitempty"`
}
func (a ActionOption) TableName() string { return "action_options" }
func (a ActionOption) GetIDName() string { return "RidActionoption" }
type ActionOptionLink struct {
RidActionoptionlink int64 `json:"rid_actionoptionlink" bun:"rid_actionoptionlink,pk"`
RidActionoptionChild int64 `json:"rid_actionoption_child" bun:"rid_actionoption_child"`
Label string `json:"label"`
// Note: no field named "rid_actionoption" — that is the parent's column.
}
func (a ActionOptionLink) TableName() string { return "action_option_links" }
func (a ActionOptionLink) GetIDName() string { return "RidActionoptionlink" }
// TestProcessNestedCUD_AsymmetricJoinColumns verifies that for a has-many relation with
// join:parentCol=childCol, the child rows are stamped with the child-side column (References),
// not the parent-side column (ForeignKey).
func TestProcessNestedCUD_AsymmetricJoinColumns(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// Mirrors: bun:"rel:has-many,join:rid_actionoption=rid_actionoption_child"
relProvider.RegisterRelation("ActionOption", "aol_rid_actionoption_child", &RelationshipInfo{
FieldName: "Links",
JSONName: "aol_rid_actionoption_child",
RelationType: "hasMany",
ForeignKey: "rid_actionoption", // parent-side column (left of join:)
References: "rid_actionoption_child", // child-side column (right of join:)
RelatedModel: ActionOptionLink{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"label": "option-a",
"aol_rid_actionoption_child": []interface{}{
map[string]interface{}{"label": "link-1"},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
ActionOption{},
nil,
"action_options",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if len(db.insertCalls) < 2 {
t.Fatalf("Expected at least 2 insert calls (parent + child), got %d", len(db.insertCalls))
}
childInsert := db.insertCalls[1]
// The fix: child must receive "rid_actionoption_child", NOT "rid_actionoption".
if childInsert["rid_actionoption_child"] == nil {
t.Error("Expected child to have rid_actionoption_child set (child-side FK column)")
}
if childInsert["rid_actionoption"] != nil {
t.Errorf("Child must not receive parent-side column rid_actionoption, got %v", childInsert["rid_actionoption"])
}
}
// TestProcessNestedCUD_BelongsToUnchanged verifies that the fix does not regress belongsTo
// relations, where ForeignKey is already the local (child) column.
func TestProcessNestedCUD_BelongsToUnchanged(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
// For belongsTo, ForeignKey is the column on the child; References is on the parent.
// The old and new code must behave identically here.
relProvider.RegisterRelation("Employee", "department", &RelationshipInfo{
FieldName: "Department",
JSONName: "department",
RelationType: "belongsTo",
ForeignKey: "DepartmentID", // child's own column
References: "ID", // parent's PK
RelatedModel: Department{},
})
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{"name": "Alice"},
},
}
_, err := processor.ProcessNestedCUD(
context.Background(),
"insert",
data,
Department{},
nil,
"departments",
)
if err != nil {
t.Fatalf("ProcessNestedCUD failed: %v", err)
}
if len(db.insertCalls) < 2 {
t.Fatalf("Expected at least 2 inserts, got %d", len(db.insertCalls))
}
// Employees relation uses has_many (old-style) so it goes through the parentIDs injection path,
// not the foreignKeyFieldName path. Just confirm no panic and employee is inserted.
if db.insertCalls[0]["name"] != "Engineering" {
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
}
}
func TestProcessNestedCUD_AddAlias(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"_request": "add",
"name": "New Department",
}
result, err := processor.ProcessNestedCUD(context.Background(), "insert", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with _request=add failed: %v", err)
}
if result.ID == nil {
t.Error("Expected result.ID to be set after add")
}
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert call, got %d", len(db.insertCalls))
}
}
func TestProcessNestedCUD_RemoveAlias(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"_request": "remove",
"ID": int64(42),
}
_, err := processor.ProcessNestedCUD(context.Background(), "delete", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with _request=remove failed: %v", err)
}
if len(db.deleteCalls) != 1 {
t.Errorf("Expected 1 delete call, got %d", len(db.deleteCalls))
}
}
func TestProcessNestedCUD_NestedAddRemoveAliases(t *testing.T) {
db := newMockDatabase()
registry := &mockModelRegistry{}
relProvider := newMockRelationshipProvider()
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
FieldName: "Employees",
JSONName: "employees",
RelationType: "has_many",
ForeignKey: "DepartmentID",
RelatedModel: Employee{},
})
processor := NewNestedCUDProcessor(db, registry, relProvider)
data := map[string]interface{}{
"ID": int64(1),
"name": "Engineering",
"employees": []interface{}{
map[string]interface{}{"_request": "add", "name": "Alice"},
map[string]interface{}{"_request": "remove", "ID": int64(5)},
},
}
_, err := processor.ProcessNestedCUD(context.Background(), "update", data, Department{}, nil, "departments")
if err != nil {
t.Fatalf("ProcessNestedCUD with nested add/remove failed: %v", err)
}
if len(db.insertCalls) != 1 {
t.Errorf("Expected 1 insert (add alias) for employee, got %d", len(db.insertCalls))
}
if len(db.deleteCalls) != 1 {
t.Errorf("Expected 1 delete (remove alias) for employee, got %d", len(db.deleteCalls))
}
}
func TestGetPrimaryKeyName(t *testing.T) { func TestGetPrimaryKeyName(t *testing.T) {
dept := Department{} dept := Department{}
pkName := reflection.GetPrimaryKeyName(dept) pkName := reflection.GetPrimaryKeyName(dept)
+9
View File
@@ -614,6 +614,15 @@ func extractTableAndColumn(cond string) (table string, column string) {
// Remove any quotes // Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'") columnRef = strings.Trim(columnRef, "`\"'")
// If the left side is a parenthesized subquery (starts with '(' and contains SQL keywords),
// don't attempt prefix extraction from inside it.
if len(columnRef) > 0 && columnRef[0] == '(' {
lowerRef := strings.ToLower(columnRef)
if strings.Contains(lowerRef, "select ") || strings.Contains(lowerRef, " from ") || strings.Contains(lowerRef, " where ") {
return "", ""
}
}
// Check if there's a function call (contains opening parenthesis) // Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(") openParenIdx := strings.Index(columnRef, "(")
+42 -6
View File
@@ -3,6 +3,7 @@ package common
import ( import (
"fmt" "fmt"
"reflect" "reflect"
"sort"
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -43,7 +44,7 @@ func (v *ColumnValidator) buildValidColumns() {
for i := 0; i < modelType.NumField(); i++ { for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i) field := modelType.Field(i)
if !field.IsExported() { if !field.IsExported() || field.Anonymous {
continue continue
} }
@@ -125,6 +126,16 @@ func (v *ColumnValidator) IsValidColumn(column string) bool {
return v.ValidateColumn(column) == nil return v.ValidateColumn(column) == nil
} }
// Columns returns all valid column names known to this validator
func (v *ColumnValidator) Columns() []string {
cols := make([]string, 0, len(v.validColumns))
for col := range v.validColumns {
cols = append(cols, col)
}
sort.Strings(cols)
return cols
}
// FilterValidColumns filters a list of columns, returning only valid ones // FilterValidColumns filters a list of columns, returning only valid ones
// Logs warnings for any invalid columns // Logs warnings for any invalid columns
func (v *ColumnValidator) FilterValidColumns(columns []string) []string { func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
@@ -224,7 +235,19 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
// Filter Filter columns // Filter Filter columns
validFilters := make([]FilterOption, 0, len(options.Filters)) validFilters := make([]FilterOption, 0, len(options.Filters))
for _, filter := range options.Filters { for _, filter := range options.Filters {
if v.IsValidColumn(filter.Column) { if strings.EqualFold(filter.Column, "all") {
allCols := v.Columns()
if len(filtered.Columns) > 0 {
allCols = filtered.Columns
}
for _, col := range allCols {
expanded := filter
expanded.Column = col
expanded.LogicOperator = "OR"
validFilters = append(validFilters, expanded)
}
} else if v.IsValidColumn(filter.Column) {
validFilters = append(validFilters, filter) validFilters = append(validFilters, filter)
} else { } else {
logger.Warn("Invalid column in filter '%s' removed", filter.Column) logger.Warn("Invalid column in filter '%s' removed", filter.Column)
@@ -266,11 +289,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
// Filter Preload columns // Filter Preload columns
validPreloads := make([]PreloadOption, 0, len(options.Preload)) validPreloads := make([]PreloadOption, 0, len(options.Preload))
modelType := reflect.TypeOf(v.model)
if modelType != nil && modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
for idx := range options.Preload { for idx := range options.Preload {
preload := options.Preload[idx] preload := options.Preload[idx]
filteredPreload := preload filteredPreload := preload
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns) // Use the related model's validator for preload columns/filters/sorts
preloadValidator := v
if modelType != nil {
if relInfo := GetRelationshipInfo(modelType, preload.Relation); relInfo != nil && relInfo.RelatedModel != nil {
preloadValidator = NewColumnValidator(relInfo.RelatedModel)
}
}
filteredPreload.Columns = preloadValidator.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = preloadValidator.FilterValidColumns(preload.OmitColumns)
// Preserve SqlJoins and JoinAliases for preloads with custom joins // Preserve SqlJoins and JoinAliases for preloads with custom joins
filteredPreload.SqlJoins = preload.SqlJoins filteredPreload.SqlJoins = preload.SqlJoins
@@ -279,7 +315,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
// Filter preload filters // Filter preload filters
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters)) validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
for _, filter := range preload.Filters { for _, filter := range preload.Filters {
if v.IsValidColumn(filter.Column) { if preloadValidator.IsValidColumn(filter.Column) {
validPreloadFilters = append(validPreloadFilters, filter) validPreloadFilters = append(validPreloadFilters, filter)
} else { } else {
// Check if the filter column references a joined table alias // Check if the filter column references a joined table alias
@@ -302,7 +338,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
// Filter preload sort columns // Filter preload sort columns
validPreloadSorts := make([]SortOption, 0, len(preload.Sort)) validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
for _, sort := range preload.Sort { for _, sort := range preload.Sort {
if v.IsValidColumn(sort.Column) { if preloadValidator.IsValidColumn(sort.Column) {
validPreloadSorts = append(validPreloadSorts, sort) validPreloadSorts = append(validPreloadSorts, sort)
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") { } else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
// Allow sort by expression/subquery, but validate for security // Allow sort by expression/subquery, but validate for security
+81
View File
@@ -464,3 +464,84 @@ func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column) t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
} }
} }
// RelatedModel is used by PreloadParentModel to test preload column validation.
type RelatedModel struct {
RelatedID int64 `bun:"related_id,pk"`
Functionname string `bun:"functionname"`
}
// PreloadParentModel has a has-one relation to RelatedModel. The json tag on
// the relation field is the name used in x-preload headers.
type PreloadParentModel struct {
ID int64 `bun:"id,pk"`
Name string `bun:"name"`
RELATED *RelatedModel `json:"RELATED" bun:"rel:has-one,join:id=related_id"`
}
// TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel verifies
// that preload columns are validated against the related model's fields, not the
// parent model's fields. This is the fix for the bug where specifying a column
// that exists only on the relation (e.g. "functionname") was incorrectly filtered
// out because it doesn't exist on the parent model.
func TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel(t *testing.T) {
validator := NewColumnValidator(PreloadParentModel{})
options := RequestOptions{
Preload: []PreloadOption{
{
Relation: "RELATED",
// "functionname" exists on RelatedModel but NOT on PreloadParentModel.
// "name" exists on PreloadParentModel but NOT on RelatedModel.
// "nonexistent" exists on neither.
Columns: []string{"functionname", "name", "nonexistent"},
},
},
}
filtered := validator.FilterRequestOptions(options)
if len(filtered.Preload) != 1 {
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
}
cols := filtered.Preload[0].Columns
// Only "functionname" should survive: it belongs to RelatedModel.
if len(cols) != 1 {
t.Errorf("Expected 1 preload column, got %d: %v", len(cols), cols)
}
if len(cols) > 0 && cols[0] != "functionname" {
t.Errorf("Expected preload column 'functionname', got '%s'", cols[0])
}
}
// TestFilterRequestOptions_PreloadColumnsParentModelFallback verifies that when
// a preload relation is not found on the parent model, column validation falls
// back to the parent model's validator (no panic, no silent pass-through).
func TestFilterRequestOptions_PreloadColumnsParentModelFallback(t *testing.T) {
validator := NewColumnValidator(PreloadParentModel{})
options := RequestOptions{
Preload: []PreloadOption{
{
Relation: "UNKNOWN_RELATION",
Columns: []string{"id", "functionname"},
},
},
}
filtered := validator.FilterRequestOptions(options)
if len(filtered.Preload) != 1 {
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
}
cols := filtered.Preload[0].Columns
// Falls back to parent model: only "id" is valid on PreloadParentModel.
if len(cols) != 1 {
t.Errorf("Expected 1 preload column (fallback to parent), got %d: %v", len(cols), cols)
}
if len(cols) > 0 && cols[0] != "id" {
t.Errorf("Expected preload column 'id', got '%s'", cols[0])
}
}
+33 -2
View File
@@ -174,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
varName := kw[1 : len(kw)-1] // strip [ and ] varName := kw[1 : len(kw)-1] // strip [ and ]
if val, ok := variables[varName]; ok { if val, ok := variables[varName]; ok {
if strVal := fmt.Sprintf("%v", val); strVal != "" { if strVal := fmt.Sprintf("%v", val); strVal != "" {
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue")) sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
continue continue
} }
} }
@@ -533,7 +533,7 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
varName := kw[1 : len(kw)-1] // strip [ and ] varName := kw[1 : len(kw)-1] // strip [ and ]
if val, ok := variables[varName]; ok { if val, ok := variables[varName]; ok {
if strVal := fmt.Sprintf("%v", val); strVal != "" { if strVal := fmt.Sprintf("%v", val); strVal != "" {
sqlquery = strings.ReplaceAll(sqlquery, kw, ValidSQL(strVal, "colvalue")) sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
continue continue
} }
} }
@@ -1006,6 +1006,37 @@ func IsNumeric(s string) bool {
return err == nil return err == nil
} }
// isInsideDollarQuote reports whether the first occurrence of placeholder in sqlquery
// is immediately surrounded by dollar-sign characters (i.e. inside a $...$-quoted string).
// Dollar-quoted strings pass content through literally — no backslash processing — so
// values placed there must NOT have their backslashes escaped.
func isInsideDollarQuote(sqlquery, placeholder string) bool {
idx := strings.Index(sqlquery, placeholder)
if idx < 0 {
return false
}
endIdx := idx + len(placeholder)
charBefore := byte(0)
charAfter := byte(0)
if idx > 0 {
charBefore = sqlquery[idx-1]
}
if endIdx < len(sqlquery) {
charAfter = sqlquery[endIdx]
}
return charBefore == '$' || charAfter == '$'
}
// safeSubstituteVar returns value sanitised for the quoting context that surrounds
// placeholder in sqlquery: raw (no backslash escaping) for dollar-quoted contexts,
// ValidSQL("colvalue") escaping for everything else.
func safeSubstituteVar(sqlquery, placeholder, value string) string {
if isInsideDollarQuote(sqlquery, placeholder) {
return value
}
return ValidSQL(value, "colvalue")
}
// getReplacementForBlankParam determines the replacement value for an unused parameter // getReplacementForBlankParam determines the replacement value for an unused parameter
// based on whether it appears within quotes in the SQL query. // based on whether it appears within quotes in the SQL query.
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$) // It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
+63 -8
View File
@@ -836,7 +836,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
err := h.db.RunInTransaction(ctx, func(tx common.Database) error { err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
// First, read the existing record from the database // First, read the existing record from the database
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*") selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...)
// Apply conditions to select // Apply conditions to select
if urlID != "" { if urlID != "" {
@@ -955,13 +955,34 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
return return
} }
// Fetch the updated record after the transaction commits to capture any trigger changes
updatedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(updatedRecord).Column(reflection.GetSQLModelColumns(model)...)
if urlID != "" {
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
} else if reqID != nil {
switch id := reqID.(type) {
case string:
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
case []string:
if len(id) > 0 {
fetchQuery = fetchQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
}
}
}
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record: %v", err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
logger.Info("Successfully updated record(s)") logger.Info("Successfully updated record(s)")
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, data, nil) h.sendResponse(w, updatedRecord, nil)
case []map[string]interface{}: case []map[string]interface{}:
// Batch update with array of objects // Batch update with array of objects
@@ -1017,7 +1038,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
// First, read the existing record // First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil { if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue // Skip if record not found continue // Skip if record not found
@@ -1089,13 +1110,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
return return
} }
logger.Info("Successfully updated %d records", len(updates))
// Fetch updated records after the transaction commits to capture any trigger changes
fetchedUpdates := make([]interface{}, 0, len(updates))
for _, item := range updates {
if itemID, ok := item["id"]; ok && itemID != nil {
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
fetchedUpdates = append(fetchedUpdates, fetchedRecord)
}
}
logger.Info("Successfully updated %d records", len(fetchedUpdates))
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, updates, nil) h.sendResponse(w, fetchedUpdates, nil)
case []interface{}: case []interface{}:
// Batch update with []interface{} // Batch update with []interface{}
@@ -1157,7 +1194,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
// First, read the existing record // First, read the existing record
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface() existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID) selectQuery := tx.NewSelect().Model(existingRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := selectQuery.ScanModel(ctx); err != nil { if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
continue // Skip if record not found continue // Skip if record not found
@@ -1232,13 +1269,31 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err) h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating records", err)
return return
} }
logger.Info("Successfully updated %d records", len(list))
// Fetch updated records after the transaction commits to capture any trigger changes
fetchedList := make([]interface{}, 0, len(list))
for _, item := range list {
if itemMap, ok := item.(map[string]interface{}); ok {
if itemID, ok := itemMap["id"]; ok && itemID != nil {
fetchedRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
fetchQuery := h.db.NewSelect().Model(fetchedRecord).Column(reflection.GetSQLModelColumns(model)...).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
if err := fetchQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record with ID %v: %v", itemID, err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
fetchedList = append(fetchedList, fetchedRecord)
}
}
}
logger.Info("Successfully updated %d records", len(fetchedList))
// Invalidate cache for this table // Invalidate cache for this table
cacheTags := buildCacheTags(schema, tableName) cacheTags := buildCacheTags(schema, tableName)
if err := invalidateCacheForTags(ctx, cacheTags); err != nil { if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err) logger.Warn("Failed to invalidate cache for table %s: %v", tableName, err)
} }
h.sendResponse(w, list, nil) h.sendResponse(w, fetchedList, nil)
default: default:
logger.Error("Invalid data type for update operation: %T", data) logger.Error("Invalid data type for update operation: %T", data)
+51 -21
View File
@@ -1218,8 +1218,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" { if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName) query = query.Table(tableName)
} }
fields := reflection.GetSQLModelColumns(model)
query = query.Returning("*") query = query.Returning(fields...)
// Execute BeforeScan hooks - pass query chain so hooks can modify it // Execute BeforeScan hooks - pass query chain so hooks can modify it
itemHookCtx := &HookContext{ itemHookCtx := &HookContext{
@@ -1480,18 +1480,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
} }
} }
// Fetch the updated record to return the new values _ = result
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
return fmt.Errorf("failed to fetch updated record: %w", err)
}
updatedRecord = modelValue
// Store result for hooks
hookCtx.Result = updatedRecord
_ = result // Keep result variable for potential future use
return nil return nil
}) })
@@ -1501,6 +1490,16 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
return return
} }
// Fetch the updated record after the transaction commits to capture any trigger changes
fetchedRecord := reflect.New(reflect.TypeOf(model)).Interface()
selectQuery := h.db.NewSelect().Model(fetchedRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
if err := selectQuery.ScanModel(ctx); err != nil {
logger.Error("Failed to fetch updated record: %v", err)
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
return
}
updatedRecord = fetchedRecord
// Merge the updated record with the original request data // Merge the updated record with the original request data
// This preserves extra keys from the request and updates values from the database // This preserves extra keys from the request and updates values from the database
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap) mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
@@ -2012,11 +2011,15 @@ func (h *Handler) processChildRelationsForField(
// Priority: Use foreign key field name if specified, otherwise use parent's PK name // Priority: Use foreign key field name if specified, otherwise use parent's PK name
var foreignKeyFieldName string var foreignKeyFieldName string
if relInfo.ForeignKey != "" { if relInfo.ForeignKey != "" {
// Get the JSON name for the foreign key field in the child model // For has-many/has-one: join:parentCol=childCol
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey) // ForeignKey = parent side, References = child side (where we actually set the value)
childField := relInfo.ForeignKey
if (relInfo.RelationType == "hasMany" || relInfo.RelationType == "hasOne") && relInfo.References != "" {
childField = relInfo.References
}
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, childField)
if foreignKeyFieldName == "" { if foreignKeyFieldName == "" {
// Fallback to lowercase field name foreignKeyFieldName = strings.ToLower(childField)
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
} }
} else { } else {
// Fallback: use parent's primary key name // Fallback: use parent's primary key name
@@ -2040,7 +2043,10 @@ func (h *Handler) processChildRelationsForField(
// Process based on relation type and data structure // Process based on relation type and data structure
switch v := relationValue.(type) { switch v := relationValue.(type) {
case map[string]interface{}: case map[string]interface{}:
// Single related object - add parent ID to foreign key field if !isValidNestedRequest(v) {
logger.Debug("Skipping single relation %s - missing or invalid _request value", relationName)
return nil
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
v[foreignKeyFieldName] = parentID v[foreignKeyFieldName] = parentID
@@ -2057,7 +2063,10 @@ func (h *Handler) processChildRelationsForField(
// Multiple related objects // Multiple related objects
for i, item := range v { for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok { if itemMap, ok := item.(map[string]interface{}); ok {
// Add parent ID to foreign key field if !isValidNestedRequest(itemMap) {
logger.Debug("Skipping relation array[%d] %s - missing or invalid _request value", i, relationName)
continue
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID itemMap[foreignKeyFieldName] = parentID
@@ -2075,7 +2084,10 @@ func (h *Handler) processChildRelationsForField(
case []map[string]interface{}: case []map[string]interface{}:
// Multiple related objects (typed slice) // Multiple related objects (typed slice)
for i, itemMap := range v { for i, itemMap := range v {
// Add parent ID to foreign key field if !isValidNestedRequest(itemMap) {
logger.Debug("Skipping relation typed array[%d] %s - missing or invalid _request value", i, relationName)
continue
}
// IMPORTANT: In recursive relationships, don't overwrite the primary key // IMPORTANT: In recursive relationships, don't overwrite the primary key
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName { if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
itemMap[foreignKeyFieldName] = parentID itemMap[foreignKeyFieldName] = parentID
@@ -2096,6 +2108,24 @@ func (h *Handler) processChildRelationsForField(
return nil return nil
} }
// isValidNestedRequest returns true only when the item carries a _request key
// whose value is one of the recognised mutation verbs.
func isValidNestedRequest(item map[string]interface{}) bool {
raw, ok := item["_request"]
if !ok {
return false
}
s, ok := raw.(string)
if !ok {
return false
}
switch strings.ToLower(strings.TrimSpace(s)) {
case "insert", "add", "change", "update", "delete", "remove":
return true
}
return false
}
// getTableNameForRelatedModel gets the table name for a related model. // getTableNameForRelatedModel gets the table name for a related model.
// If the model's TableName() is schema-qualified (e.g. "public.users") the // If the model's TableName() is schema-qualified (e.g. "public.users") the
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise. // separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
+39
View File
@@ -352,6 +352,45 @@ func (m *mockRegistry) GetAllModels() map[string]interface{} {
return m.models return m.models
} }
// TestIsValidNestedRequest verifies that only the allowed _request verbs are accepted
// and that items missing the key are rejected.
func TestIsValidNestedRequest(t *testing.T) {
tests := []struct {
name string
item map[string]interface{}
expected bool
}{
// Valid verbs
{name: "insert", item: map[string]interface{}{"_request": "insert"}, expected: true},
{name: "add", item: map[string]interface{}{"_request": "add"}, expected: true},
{name: "update", item: map[string]interface{}{"_request": "update"}, expected: true},
{name: "change", item: map[string]interface{}{"_request": "change"}, expected: true},
{name: "delete", item: map[string]interface{}{"_request": "delete"}, expected: true},
{name: "remove", item: map[string]interface{}{"_request": "remove"}, expected: true},
// Case-insensitive
{name: "INSERT uppercase", item: map[string]interface{}{"_request": "INSERT"}, expected: true},
{name: "Remove mixed case", item: map[string]interface{}{"_request": "Remove"}, expected: true},
// Whitespace trimmed
{name: "insert with spaces", item: map[string]interface{}{"_request": " insert "}, expected: true},
// Invalid / missing
{name: "missing _request", item: map[string]interface{}{"name": "foo"}, expected: false},
{name: "empty string", item: map[string]interface{}{"_request": ""}, expected: false},
{name: "unknown verb", item: map[string]interface{}{"_request": "create"}, expected: false},
{name: "unknown verb modify", item: map[string]interface{}{"_request": "modify"}, expected: false},
{name: "non-string value", item: map[string]interface{}{"_request": 42}, expected: false},
{name: "nil value", item: map[string]interface{}{"_request": nil}, expected: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := isValidNestedRequest(tt.item)
if got != tt.expected {
t.Errorf("isValidNestedRequest(%v) = %v, want %v", tt.item, got, tt.expected)
}
})
}
}
// TestMultiLevelRelationExtraction tests extracting deeply nested relations // TestMultiLevelRelationExtraction tests extracting deeply nested relations
func TestMultiLevelRelationExtraction(t *testing.T) { func TestMultiLevelRelationExtraction(t *testing.T) {
registry := &mockRegistry{ registry := &mockRegistry{
+14 -7
View File
@@ -6,9 +6,9 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"regexp" "regexp"
"sort"
"strconv" "strconv"
"strings" "strings"
"unicode/utf8"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -102,11 +102,6 @@ func DecodeParam(pStr string) (string, error) {
if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") { if strings.HasPrefix(code, "ZIP_") || strings.HasPrefix(code, "__") {
code, _ = DecodeParam(code) code, _ = DecodeParam(code)
} else {
strDat, err := base64.StdEncoding.DecodeString(code)
if err == nil && utf8.Valid(strDat) {
code = string(strDat)
}
} }
return code, nil return code, nil
@@ -146,9 +141,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
combinedParams[strings.ToLower(key)] = value combinedParams[strings.ToLower(key)] = value
} }
sortedKeys := make([]string, 0, len(combinedParams))
for key := range combinedParams {
sortedKeys = append(sortedKeys, key)
}
sort.Slice(sortedKeys, func(i, j int) bool {
if sortedKeys[i] != sortedKeys[j] {
return sortedKeys[i] < sortedKeys[j]
}
return combinedParams[sortedKeys[i]] < combinedParams[sortedKeys[j]]
})
// Process each parameter (from both headers and query params) // Process each parameter (from both headers and query params)
// Note: keys are already normalized to lowercase in combinedParams // Note: keys are already normalized to lowercase in combinedParams
for key, value := range combinedParams { for _, key := range sortedKeys {
value := combinedParams[key]
// Decode value if it's base64 encoded // Decode value if it's base64 encoded
decodedValue := decodeHeaderValue(value) decodedValue := decodeHeaderValue(value)
+19 -10
View File
@@ -70,6 +70,25 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Try to open the file // Try to open the file
file, err := m.provider.Open(strings.TrimPrefix(filePath, "/")) file, err := m.provider.Open(strings.TrimPrefix(filePath, "/"))
if err != nil { if err != nil {
// For extensionless paths, also try path/index.html
if path.Ext(filePath) == "" {
indexFallback := path.Join(filePath, "index.html")
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
indexFallback = fmt.Sprintf("%s.html", filePath)
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
}
// File doesn't exist - check if we should use fallback // File doesn't exist - check if we should use fallback
if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) { if m.fallbackStrategy != nil && m.fallbackStrategy.ShouldFallback(filePath) {
fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath) fallbackPath := m.fallbackStrategy.GetFallbackPath(filePath)
@@ -80,16 +99,6 @@ func (m *mountPoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
// For extensionless paths, also try path/index.html
if path.Ext(filePath) == "" {
indexFallback := path.Join(filePath, "index.html")
file, err = m.provider.Open(strings.TrimPrefix(indexFallback, "/"))
if err == nil {
defer file.Close()
m.serveFile(w, r, indexFallback, file)
return
}
}
} }
// No fallback or fallback failed - return 404 // No fallback or fallback failed - return 404