Compare commits

...

3 Commits

Author SHA1 Message Date
Hein 09a3dc92b9 fix(restheadspec): normalize empty results to objects instead of arrays 2026-05-18 14:37:46 +02:00
Hein 6590cd789a fix(nestedCUD): re-select rows after insert/update for accurate state
* Ensure result.Data reflects DB-generated defaults after insert.
* Update result.Data with current DB state after update.
2026-05-18 13:10:13 +02:00
Hein 4244e838b1 fix(reflection): enhance GetForeignKeyColumn logic for self-referential models
* Add support for self-referential models in GetForeignKeyColumn
* Update comments for clarity on foreign key resolution strategies
* Introduce selfRefItem struct for testing self-referential behavior
2026-05-18 13:03:07 +02:00
5 changed files with 138 additions and 49 deletions
+34 -1
View File
@@ -125,6 +125,13 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.AffectedRows = 1
result.Data = regularData
// Re-select the inserted row so result.Data reflects DB-generated defaults.
if row, err := p.processSelect(ctx, tableName, id); err != nil {
logger.Warn("Select after insert failed: table=%s, id=%v, error=%v", tableName, id, err)
} else if len(row) > 0 {
result.Data = row
}
// Process child relations after parent insert (to get parent ID)
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
@@ -146,9 +153,16 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.AffectedRows = rows
result.Data = regularData
// Re-select the updated row so result.Data reflects current DB state.
if row, err := p.processSelect(ctx, tableName, result.ID); err != nil {
logger.Warn("Select after update failed: table=%s, id=%v, error=%v", tableName, result.ID, err)
} else if len(row) > 0 {
result.Data = row
}
// Process child relations for update
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], regularData, err)
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
} else {
@@ -234,6 +248,8 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
return
}
pkCol := reflection.GetPrimaryKeyName(reflect.New(modelType).Interface())
for parentKey, parentID := range parentIDs {
dbColNames := reflection.GetForeignKeyColumn(modelType, parentKey)
@@ -255,6 +271,9 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
}
for _, dbColName := range dbColNames {
if pkCol != "" && strings.EqualFold(dbColName, pkCol) {
continue
}
if _, exists := data[dbColName]; !exists {
logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
data[dbColName] = parentID
@@ -289,6 +308,20 @@ func (p *NestedCUDProcessor) processInsert(
return id, nil
}
// processSelect fetches the row identified by id from tableName into a flat map.
// Used to populate result.Data with the actual DB state after insert/update.
func (p *NestedCUDProcessor) processSelect(ctx context.Context, tableName string, id interface{}) (map[string]interface{}, error) {
pkName := reflection.GetPrimaryKeyName(tableName)
var row map[string]interface{}
if err := p.db.NewSelect().
Table(tableName).
Where(fmt.Sprintf("%s = ?", QuoteIdent(pkName)), id).
Scan(ctx, &row); err != nil {
return nil, fmt.Errorf("select after write failed: %w", err)
}
return row, nil
}
// processUpdate handles update operation
func (p *NestedCUDProcessor) processUpdate(
ctx context.Context,
+47 -11
View File
@@ -974,16 +974,21 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
}
// GetForeignKeyColumn returns the DB column names of the foreign key(s) that
// the relation field identified by parentKey owns on modelType. Composite keys
// (e.g. bun "join:a=b,join:c=d" or GORM "foreignKey:ColA,ColB") yield multiple
// entries. Returns nil when no tag is found (caller should fall back to
// convention).
// relate parentKey to modelType. Composite keys (e.g. bun "join:a=b,join:c=d"
// or GORM "foreignKey:ColA,ColB") yield multiple entries. Returns nil when no
// tag is found (caller should fall back to convention).
//
// It checks tags in priority order:
// 1. Bun join: tag — e.g. `bun:"rel:belongs-to,join:department_id=id"` → ["department_id"]
// 2. GORM foreignKey: tag — e.g. `gorm:"foreignKey:DepartmentID"` → [column of DepartmentID field]
// Two lookup strategies are tried in order:
//
// parentKey is matched case-insensitively against the field name and JSON tag.
// 1. Relation-field match: find a field whose name/json equals parentKey, then
// read its bun join: or GORM foreignKey: tag and return the local columns.
// e.g. parentKey="department", field `Department bun:"join:dept_id=id"` → ["dept_id"]
//
// 2. Join left-side scan: scan every bun join tag in the struct for pairs whose
// left side equals parentKey and return the right-side (child FK) columns.
// e.g. parentKey="rid_mastertaskitem", field `Children bun:"join:rid_mastertaskitem=rid_parentmastertaskitem"` → ["rid_parentmastertaskitem"]
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
// has a GORM tag but carries no explicit FK — callers should use convention.
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice {
modelType = modelType.Elem()
@@ -992,6 +997,7 @@ func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
return nil
}
// Strategy 1: match parentKey against a field's name/json tag.
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
@@ -1001,9 +1007,11 @@ func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
continue
}
bunTag := field.Tag.Get("bun")
// Bun: join:local_col=foreign_col (one join: part per pair)
var bunCols []string
for _, part := range strings.Split(field.Tag.Get("bun"), ",") {
for _, part := range strings.Split(bunTag, ",") {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "join:") {
pair := strings.TrimPrefix(part, "join:")
@@ -1033,10 +1041,38 @@ func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
}
}
return nil
// The field matched by name/json but has no explicit FK tag. If it is a
// declared relation field (rel:) or carries a GORM tag, the caller should
// use naming convention — don't fall through to strategy 2. Otherwise the
// matched field is a plain scalar column; proceed to the join left-side scan.
if strings.Contains(bunTag, "rel:") || field.Tag.Get("gorm") != "" {
return nil
}
break
}
return nil
// Strategy 2: scan every field's bun join tag for pairs whose left side (the
// parent's column) matches parentKey; the right side is the child FK column.
// This handles cases where parentKey is a raw column name rather than a
// relation field name (e.g. self-referential or has-many relationships).
seen := map[string]bool{}
var cols []string
for i := 0; i < modelType.NumField(); i++ {
for _, part := range strings.Split(modelType.Field(i).Tag.Get("bun"), ",") {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "join:") {
pair := strings.TrimPrefix(part, "join:")
if idx := strings.Index(pair, "="); idx > 0 {
left, right := pair[:idx], pair[idx+1:]
if strings.EqualFold(left, parentKey) && !seen[right] {
seen[right] = true
cols = append(cols, right)
}
}
}
}
}
return cols // nil if empty
}
// GetRelationModel gets the model type for a relation field
@@ -37,6 +37,17 @@ type gormCompositeEmployee struct {
Department *fkDept `gorm:"foreignKey:DeptID,TenantID" json:"department"`
}
// selfRefItem mimics a self-referential model (like mastertaskitem) where the
// parent PK column appears as the left side of a has-many join tag.
type selfRefItem struct {
RidItem int32 `json:"rid_item" bun:"rid_item,type:integer,pk"`
RidParentItem int32 `json:"rid_parentitem" bun:"rid_parentitem,type:integer"`
// has-one (single parent pointer)
Parent *selfRefItem `json:"Parent,omitempty" bun:"rel:has-one,join:rid_item=rid_parentitem"`
// has-many (child collection) — same join, duplicate right-side must be deduped
Children []*selfRefItem `json:"Children,omitempty" bun:"rel:has-many,join:rid_item=rid_parentitem"`
}
// conventionEmployee has no explicit FK tag — relies on naming convention.
type conventionEmployee struct {
DepartmentID string `json:"department_id"`
@@ -101,6 +112,14 @@ func TestGetForeignKeyColumn(t *testing.T) {
want: []string{"dept_id", "tenant_id"},
},
// Join left-side scan (parentKey is a raw column name, not a relation field name)
{
name: "self-referential: parent PK column returns child FK column",
modelType: reflect.TypeOf(selfRefItem{}),
parentKey: "rid_item",
want: []string{"rid_parentitem"},
},
// Pointer and slice unwrapping
{
name: "pointer to struct is unwrapped",
+27 -27
View File
@@ -9,29 +9,29 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test that normalizeResultArray returns empty array when no records found without ID
// Test that normalizeResultArray returns empty object when no records found (single-record mode)
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
handler := &Handler{}
tests := []struct {
name string
input interface{}
shouldBeEmptyArr bool
name string
input interface{}
shouldBeEmptyObj bool
}{
{
name: "nil should return empty array",
input: nil,
shouldBeEmptyArr: true,
name: "nil should return empty object",
input: nil,
shouldBeEmptyObj: true,
},
{
name: "empty slice should return empty array",
input: []*EmptyTestModel{},
shouldBeEmptyArr: true,
name: "empty slice should return empty object",
input: []*EmptyTestModel{},
shouldBeEmptyObj: true,
},
{
name: "single element should return the element",
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
shouldBeEmptyArr: false,
name: "single element should return the element",
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
shouldBeEmptyObj: false,
},
{
name: "multiple elements should return the slice",
@@ -39,7 +39,7 @@ func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
{ID: 1, Name: "test1"},
{ID: 2, Name: "test2"},
},
shouldBeEmptyArr: false,
shouldBeEmptyObj: false,
},
}
@@ -47,25 +47,25 @@ func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result := handler.normalizeResultArray(tt.input)
// For cases that should return empty array
if tt.shouldBeEmptyArr {
emptyArr, ok := result.([]interface{})
// For cases that should return empty object
if tt.shouldBeEmptyObj {
emptyObj, ok := result.(map[string]interface{})
if !ok {
t.Errorf("Expected empty array []interface{}{}, got %T: %v", result, result)
t.Errorf("Expected empty object map[string]interface{}{}, got %T: %v", result, result)
return
}
if len(emptyArr) != 0 {
t.Errorf("Expected empty array with length 0, got length %d", len(emptyArr))
if len(emptyObj) != 0 {
t.Errorf("Expected empty object with length 0, got length %d", len(emptyObj))
}
// Verify it serializes to [] and not null
// Verify it serializes to {} and not null
jsonBytes, err := json.Marshal(result)
if err != nil {
t.Errorf("Failed to marshal result: %v", err)
return
}
if string(jsonBytes) != "[]" {
t.Errorf("Expected JSON '[]', got '%s'", string(jsonBytes))
if string(jsonBytes) != "{}" {
t.Errorf("Expected JSON '{}', got '%s'", string(jsonBytes))
}
}
})
@@ -138,12 +138,12 @@ func TestSendResponseWithOptions_NoDataFoundHeader(t *testing.T) {
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
}
// Check status code is 200
if mockWriter.statusCode != 200 {
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
// Check status code is 204 when no records found
if mockWriter.statusCode != 204 {
t.Errorf("Expected status code 204, got %d", mockWriter.statusCode)
}
// Verify the body is an empty array
// Verify the body is an empty array (list request, SingleRecordAsObject not set)
if mockWriter.body == nil {
t.Error("Expected body to be set, got nil")
} else {
+11 -10
View File
@@ -2502,14 +2502,16 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
w.SetHeader("X-No-Data-Found", "true")
}
w.WriteHeader(http.StatusOK)
// Normalize single-record arrays to objects if requested
if options != nil && options.SingleRecordAsObject {
data = h.normalizeResultArray(data)
}
// Return data as-is without wrapping in common.Response
if dataLen == 0 {
w.WriteHeader(http.StatusNoContent)
} else {
w.WriteHeader(http.StatusOK)
}
if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err)
@@ -2520,7 +2522,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil {
return []interface{}{}
return map[string]interface{}{}
}
// Use reflection to check if data is a slice or array
@@ -2535,15 +2537,15 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
// Return the single element
return dataValue.Index(0).Interface()
} else if dataValue.Len() == 0 {
// Keep empty array as empty array, don't convert to empty object
return []interface{}{}
// Single-record request with no result → empty object
return map[string]interface{}{}
}
}
if dataValue.Kind() == reflect.String {
str := dataValue.String()
if str == "" || str == "null" {
return []interface{}{}
return map[string]interface{}{}
}
}
@@ -2552,9 +2554,6 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
// sendFormattedResponse sends response with formatting options
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
// Normalize single-record arrays to objects if requested
httpStatus := http.StatusOK
// Handle nil data - convert to empty array
if data == nil {
data = []interface{}{}
@@ -2566,8 +2565,10 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
dataLen := reflection.Len(data)
// Add X-No-Data-Found header when no records were found
httpStatus := http.StatusOK
if dataLen == 0 {
w.SetHeader("X-No-Data-Found", "true")
httpStatus = http.StatusNoContent
}
// Apply normalization after header is set