Compare commits

...

3 Commits

Author SHA1 Message Date
Hein
6590cd789a fix(nestedCUD): re-select rows after insert/update for accurate state
* Ensure result.Data reflects DB-generated defaults after insert.
* Update result.Data with current DB state after update.
2026-05-18 13:10:13 +02:00
Hein
4244e838b1 fix(reflection): enhance GetForeignKeyColumn logic for self-referential models
* Add support for self-referential models in GetForeignKeyColumn
* Update comments for clarity on foreign key resolution strategies
* Introduce selfRefItem struct for testing self-referential behavior
2026-05-18 13:03:07 +02:00
Hein
c42fa11c1a fix(reflection): update GetForeignKeyColumn to return multiple columns
* Change return type to []string for composite keys
* Adjust related logic in injectForeignKeys method
* Update tests to validate new behavior for composite foreign keys
2026-05-18 12:39:06 +02:00
3 changed files with 167 additions and 43 deletions

View File

@@ -125,6 +125,13 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.AffectedRows = 1 result.AffectedRows = 1
result.Data = regularData result.Data = regularData
// Re-select the inserted row so result.Data reflects DB-generated defaults.
if row, err := p.processSelect(ctx, tableName, id); err != nil {
logger.Warn("Select after insert failed: table=%s, id=%v, error=%v", tableName, id, err)
} else if len(row) > 0 {
result.Data = row
}
// Process child relations after parent insert (to get parent ID) // Process child relations after parent insert (to get parent ID)
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil { if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err) logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
@@ -146,9 +153,16 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
result.AffectedRows = rows result.AffectedRows = rows
result.Data = regularData result.Data = regularData
// Re-select the updated row so result.Data reflects current DB state.
if row, err := p.processSelect(ctx, tableName, result.ID); err != nil {
logger.Warn("Select after update failed: table=%s, id=%v, error=%v", tableName, result.ID, err)
} else if len(row) > 0 {
result.Data = row
}
// Process child relations for update // Process child relations for update
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil { if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err) logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], regularData, err)
return nil, fmt.Errorf("failed to process child relations: %w", err) return nil, fmt.Errorf("failed to process child relations: %w", err)
} }
} else { } else {
@@ -234,10 +248,12 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
return return
} }
for parentKey, parentID := range parentIDs { pkCol := reflection.GetPrimaryKeyName(reflect.New(modelType).Interface())
dbColName := reflection.GetForeignKeyColumn(modelType, parentKey)
if dbColName == "" { for parentKey, parentID := range parentIDs {
dbColNames := reflection.GetForeignKeyColumn(modelType, parentKey)
if len(dbColNames) == 0 {
// No explicit tag found — fall back to naming convention by scanning scalar fields. // No explicit tag found — fall back to naming convention by scanning scalar fields.
for i := 0; i < modelType.NumField(); i++ { for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i) field := modelType.Field(i)
@@ -248,13 +264,16 @@ func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, mode
strings.EqualFold(jsonName, parentKey+"_id") || strings.EqualFold(jsonName, parentKey+"_id") ||
strings.EqualFold(jsonName, parentKey+"id") || strings.EqualFold(jsonName, parentKey+"id") ||
strings.EqualFold(field.Name, parentKey+"ID") { strings.EqualFold(field.Name, parentKey+"ID") {
dbColName = reflection.GetColumnName(field) dbColNames = []string{reflection.GetColumnName(field)}
break break
} }
} }
} }
if dbColName != "" { for _, dbColName := range dbColNames {
if pkCol != "" && strings.EqualFold(dbColName, pkCol) {
continue
}
if _, exists := data[dbColName]; !exists { if _, exists := data[dbColName]; !exists {
logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID) logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
data[dbColName] = parentID data[dbColName] = parentID
@@ -289,6 +308,20 @@ func (p *NestedCUDProcessor) processInsert(
return id, nil return id, nil
} }
// processSelect fetches the row identified by id from tableName into a flat map.
// Used to populate result.Data with the actual DB state after insert/update.
func (p *NestedCUDProcessor) processSelect(ctx context.Context, tableName string, id interface{}) (map[string]interface{}, error) {
pkName := reflection.GetPrimaryKeyName(tableName)
var row map[string]interface{}
if err := p.db.NewSelect().
Table(tableName).
Where(fmt.Sprintf("%s = ?", QuoteIdent(pkName)), id).
Scan(ctx, &row); err != nil {
return nil, fmt.Errorf("select after write failed: %w", err)
}
return row, nil
}
// processUpdate handles update operation // processUpdate handles update operation
func (p *NestedCUDProcessor) processUpdate( func (p *NestedCUDProcessor) processUpdate(
ctx context.Context, ctx context.Context,

View File

@@ -973,23 +973,31 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
return RelationUnknown return RelationUnknown
} }
// GetForeignKeyColumn returns the DB column name of the foreign key that the // GetForeignKeyColumn returns the DB column names of the foreign key(s) that
// relation field identified by parentKey owns on modelType. // relate parentKey to modelType. Composite keys (e.g. bun "join:a=b,join:c=d"
// or GORM "foreignKey:ColA,ColB") yield multiple entries. Returns nil when no
// tag is found (caller should fall back to convention).
// //
// It checks tags in priority order: // Two lookup strategies are tried in order:
// 1. Bun join: tag — e.g. `bun:"rel:belongs-to,join:department_id=id"` → "department_id"
// 2. GORM foreignKey: tag — e.g. `gorm:"foreignKey:DepartmentID"` → column of DepartmentID field
// 3. Returns "" when no tag is found (caller should fall back to convention)
// //
// parentKey is matched case-insensitively against the field name and JSON tag. // 1. Relation-field match: find a field whose name/json equals parentKey, then
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) string { // read its bun join: or GORM foreignKey: tag and return the local columns.
// e.g. parentKey="department", field `Department bun:"join:dept_id=id"` → ["dept_id"]
//
// 2. Join left-side scan: scan every bun join tag in the struct for pairs whose
// left side equals parentKey and return the right-side (child FK) columns.
// e.g. parentKey="rid_mastertaskitem", field `Children bun:"join:rid_mastertaskitem=rid_parentmastertaskitem"` → ["rid_parentmastertaskitem"]
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
// has a GORM tag but carries no explicit FK — callers should use convention.
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice { for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice {
modelType = modelType.Elem() modelType = modelType.Elem()
} }
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
return "" return nil
} }
// Strategy 1: match parentKey against a field's name/json tag.
for i := 0; i < modelType.NumField(); i++ { for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i) field := modelType.Field(i)
@@ -999,34 +1007,72 @@ func GetForeignKeyColumn(modelType reflect.Type, parentKey string) string {
continue continue
} }
// Bun: join:local_col=foreign_col bunTag := field.Tag.Get("bun")
for _, part := range strings.Split(field.Tag.Get("bun"), ",") {
// Bun: join:local_col=foreign_col (one join: part per pair)
var bunCols []string
for _, part := range strings.Split(bunTag, ",") {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
if strings.HasPrefix(part, "join:") { if strings.HasPrefix(part, "join:") {
// join: may contain multiple pairs separated by spaces: "join:a=b join:c=d"
// but typically it's a single pair; take the first local column
pair := strings.TrimPrefix(part, "join:") pair := strings.TrimPrefix(part, "join:")
if idx := strings.Index(pair, "="); idx > 0 { if idx := strings.Index(pair, "="); idx > 0 {
return pair[:idx] bunCols = append(bunCols, pair[:idx])
} }
} }
} }
if len(bunCols) > 0 {
return bunCols
}
// GORM: foreignKey:FieldName // GORM: foreignKey:FieldA,FieldB
for _, part := range strings.Split(field.Tag.Get("gorm"), ";") { for _, part := range strings.Split(field.Tag.Get("gorm"), ";") {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
if strings.HasPrefix(part, "foreignKey:") { if strings.HasPrefix(part, "foreignKey:") {
fkFieldName := strings.TrimPrefix(part, "foreignKey:") var cols []string
if fkField, ok := modelType.FieldByName(fkFieldName); ok { for _, fkFieldName := range strings.Split(strings.TrimPrefix(part, "foreignKey:"), ",") {
return getColumnNameFromField(fkField) fkFieldName = strings.TrimSpace(fkFieldName)
if fkField, ok := modelType.FieldByName(fkFieldName); ok {
cols = append(cols, getColumnNameFromField(fkField))
}
}
if len(cols) > 0 {
return cols
} }
} }
} }
return "" // The field matched by name/json but has no explicit FK tag. If it is a
// declared relation field (rel:) or carries a GORM tag, the caller should
// use naming convention — don't fall through to strategy 2. Otherwise the
// matched field is a plain scalar column; proceed to the join left-side scan.
if strings.Contains(bunTag, "rel:") || field.Tag.Get("gorm") != "" {
return nil
}
break
} }
return "" // Strategy 2: scan every field's bun join tag for pairs whose left side (the
// parent's column) matches parentKey; the right side is the child FK column.
// This handles cases where parentKey is a raw column name rather than a
// relation field name (e.g. self-referential or has-many relationships).
seen := map[string]bool{}
var cols []string
for i := 0; i < modelType.NumField(); i++ {
for _, part := range strings.Split(modelType.Field(i).Tag.Get("bun"), ",") {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "join:") {
pair := strings.TrimPrefix(part, "join:")
if idx := strings.Index(pair, "="); idx > 0 {
left, right := pair[:idx], pair[idx+1:]
if strings.EqualFold(left, parentKey) && !seen[right] {
seen[right] = true
cols = append(cols, right)
}
}
}
}
}
return cols // nil if empty
} }
// GetRelationModel gets the model type for a relation field // GetRelationModel gets the model type for a relation field

View File

@@ -15,6 +15,13 @@ type bunEmployee struct {
Department *fkDept `bun:"rel:belongs-to,join:dept_id=id" json:"department"` Department *fkDept `bun:"rel:belongs-to,join:dept_id=id" json:"department"`
} }
// bunCompositeEmployee has a composite bun join: (two join: parts).
type bunCompositeEmployee struct {
DeptID string `bun:"dept_id" json:"dept_id"`
TenantID string `bun:"tenant_id" json:"tenant_id"`
Department *fkDept `bun:"rel:belongs-to,join:dept_id=id,join:tenant_id=id" json:"department"`
}
// gormEmployee uses gorm foreignKey: tag (mirrors testmodels.Employee). // gormEmployee uses gorm foreignKey: tag (mirrors testmodels.Employee).
type gormEmployee struct { type gormEmployee struct {
DepartmentID string `json:"department_id"` DepartmentID string `json:"department_id"`
@@ -23,6 +30,24 @@ type gormEmployee struct {
Manager *fkDept `gorm:"foreignKey:ManagerID;references:ID" json:"manager"` Manager *fkDept `gorm:"foreignKey:ManagerID;references:ID" json:"manager"`
} }
// gormCompositeEmployee has a composite GORM foreignKey.
type gormCompositeEmployee struct {
DeptID string `json:"dept_id"`
TenantID string `json:"tenant_id"`
Department *fkDept `gorm:"foreignKey:DeptID,TenantID" json:"department"`
}
// selfRefItem mimics a self-referential model (like mastertaskitem) where the
// parent PK column appears as the left side of a has-many join tag.
type selfRefItem struct {
RidItem int32 `json:"rid_item" bun:"rid_item,type:integer,pk"`
RidParentItem int32 `json:"rid_parentitem" bun:"rid_parentitem,type:integer"`
// has-one (single parent pointer)
Parent *selfRefItem `json:"Parent,omitempty" bun:"rel:has-one,join:rid_item=rid_parentitem"`
// has-many (child collection) — same join, duplicate right-side must be deduped
Children []*selfRefItem `json:"Children,omitempty" bun:"rel:has-many,join:rid_item=rid_parentitem"`
}
// conventionEmployee has no explicit FK tag — relies on naming convention. // conventionEmployee has no explicit FK tag — relies on naming convention.
type conventionEmployee struct { type conventionEmployee struct {
DepartmentID string `json:"department_id"` DepartmentID string `json:"department_id"`
@@ -39,20 +64,26 @@ func TestGetForeignKeyColumn(t *testing.T) {
name string name string
modelType reflect.Type modelType reflect.Type
parentKey string parentKey string
want string want []string
}{ }{
// Bun join: tag // Bun join: tag
{ {
name: "bun join tag returns local column", name: "bun join tag returns local column",
modelType: reflect.TypeOf(bunEmployee{}), modelType: reflect.TypeOf(bunEmployee{}),
parentKey: "department", parentKey: "department",
want: "dept_id", want: []string{"dept_id"},
}, },
{ {
name: "bun join tag matched via json tag (case-insensitive)", name: "bun join tag matched via json tag (case-insensitive)",
modelType: reflect.TypeOf(bunEmployee{}), modelType: reflect.TypeOf(bunEmployee{}),
parentKey: "Department", parentKey: "Department",
want: "dept_id", want: []string{"dept_id"},
},
{
name: "bun composite join returns all local columns",
modelType: reflect.TypeOf(bunCompositeEmployee{}),
parentKey: "department",
want: []string{"dept_id", "tenant_id"},
}, },
// GORM foreignKey: tag // GORM foreignKey: tag
@@ -60,19 +91,33 @@ func TestGetForeignKeyColumn(t *testing.T) {
name: "gorm foreignKey resolves to column name", name: "gorm foreignKey resolves to column name",
modelType: reflect.TypeOf(gormEmployee{}), modelType: reflect.TypeOf(gormEmployee{}),
parentKey: "department", parentKey: "department",
want: "department_id", want: []string{"department_id"},
}, },
{ {
name: "gorm foreignKey resolves second relation", name: "gorm foreignKey resolves second relation",
modelType: reflect.TypeOf(gormEmployee{}), modelType: reflect.TypeOf(gormEmployee{}),
parentKey: "manager", parentKey: "manager",
want: "manager_id", want: []string{"manager_id"},
}, },
{ {
name: "gorm foreignKey matched case-insensitively", name: "gorm foreignKey matched case-insensitively",
modelType: reflect.TypeOf(gormEmployee{}), modelType: reflect.TypeOf(gormEmployee{}),
parentKey: "Department", parentKey: "Department",
want: "department_id", want: []string{"department_id"},
},
{
name: "gorm composite foreignKey returns all columns",
modelType: reflect.TypeOf(gormCompositeEmployee{}),
parentKey: "department",
want: []string{"dept_id", "tenant_id"},
},
// Join left-side scan (parentKey is a raw column name, not a relation field name)
{
name: "self-referential: parent PK column returns child FK column",
modelType: reflect.TypeOf(selfRefItem{}),
parentKey: "rid_item",
want: []string{"rid_parentitem"},
}, },
// Pointer and slice unwrapping // Pointer and slice unwrapping
@@ -80,43 +125,43 @@ func TestGetForeignKeyColumn(t *testing.T) {
name: "pointer to struct is unwrapped", name: "pointer to struct is unwrapped",
modelType: reflect.TypeOf(&gormEmployee{}), modelType: reflect.TypeOf(&gormEmployee{}),
parentKey: "department", parentKey: "department",
want: "department_id", want: []string{"department_id"},
}, },
{ {
name: "slice of struct is unwrapped", name: "slice of struct is unwrapped",
modelType: reflect.TypeOf([]gormEmployee{}), modelType: reflect.TypeOf([]gormEmployee{}),
parentKey: "department", parentKey: "department",
want: "department_id", want: []string{"department_id"},
}, },
// No tag — returns "" so caller can fall back to convention // No tag — returns nil so caller can fall back to convention
{ {
name: "relation with no FK tag returns empty string", name: "relation with no FK tag returns nil",
modelType: reflect.TypeOf(conventionEmployee{}), modelType: reflect.TypeOf(conventionEmployee{}),
parentKey: "department", parentKey: "department",
want: "", want: nil,
}, },
// Unknown parent key // Unknown parent key
{ {
name: "unknown parent key returns empty string", name: "unknown parent key returns nil",
modelType: reflect.TypeOf(gormEmployee{}), modelType: reflect.TypeOf(gormEmployee{}),
parentKey: "nonexistent", parentKey: "nonexistent",
want: "", want: nil,
}, },
{ {
name: "non-struct type returns empty string", name: "non-struct type returns nil",
modelType: reflect.TypeOf(""), modelType: reflect.TypeOf(""),
parentKey: "department", parentKey: "department",
want: "", want: nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got := GetForeignKeyColumn(tt.modelType, tt.parentKey) got := GetForeignKeyColumn(tt.modelType, tt.parentKey)
if got != tt.want { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("GetForeignKeyColumn(%v, %q) = %q, want %q", tt.modelType, tt.parentKey, got, tt.want) t.Errorf("GetForeignKeyColumn(%v, %q) = %v, want %v", tt.modelType, tt.parentKey, got, tt.want)
} }
}) })
} }