diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index 82ffb8d..aeed1b5 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -9,6 +9,7 @@ import ( "github.com/uptrace/bun" "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // BunAdapter adapts Bun to work with our Database interface @@ -353,10 +354,12 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) { // BunUpdateQuery implements UpdateQuery for Bun type BunUpdateQuery struct { query *bun.UpdateQuery + model interface{} } func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery { b.query = b.query.Model(model) + b.model = model return b } @@ -366,12 +369,22 @@ func (b *BunUpdateQuery) Table(table string) common.UpdateQuery { } func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { + // Validate column is writable if model is set + if b.model != nil && !reflection.IsColumnWritable(b.model, column) { + // Skip scan-only columns + return b + } b.query = b.query.Set(column+" = ?", value) return b } func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { for column, value := range values { + // Validate column is writable if model is set + if b.model != nil && !reflection.IsColumnWritable(b.model, column) { + // Skip scan-only columns + continue + } b.query = b.query.Set(column+" = ?", value) } return b diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 4479976..1cee1ab 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -8,6 +8,7 @@ import ( "gorm.io/gorm" "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/reflection" ) // GormAdapter adapts GORM to work with our Database interface @@ -343,6 +344,12 @@ func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { } func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { + // Validate column is writable if model is set + if g.model != nil && !reflection.IsColumnWritable(g.model, column) { + // Skip read-only columns + return g + } + if g.updates == nil { g.updates = make(map[string]interface{}) } @@ -353,7 +360,18 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue } func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { - g.updates = values + // Filter out read-only columns if model is set + if g.model != nil { + filteredValues := make(map[string]interface{}) + for column, value := range values { + if reflection.IsColumnWritable(g.model, column) { + filteredValues[column] = value + } + } + g.updates = filteredValues + } else { + g.updates = values + } return g } diff --git a/pkg/common/adapters/database/update_validation_test.go b/pkg/common/adapters/database/update_validation_test.go new file mode 100644 index 0000000..f857972 --- /dev/null +++ b/pkg/common/adapters/database/update_validation_test.go @@ -0,0 +1,161 @@ +package database + +import ( + "testing" + + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +// Test models for bun +type BunTestModel struct { + ID int `bun:"id,pk"` + Name string `bun:"name"` + Email string `bun:"email"` + ComputedCol string `bun:"computed_col,scanonly"` +} + +// Test models for gorm +type GormTestModel struct { + ID int `gorm:"column:id;primaryKey"` + Name string `gorm:"column:name"` + Email string `gorm:"column:email"` + ReadOnlyCol string `gorm:"column:readonly_col;->"` + NoWriteCol string `gorm:"column:nowrite_col;<-:false"` +} + +func TestIsColumnWritable_Bun(t *testing.T) { + model := &BunTestModel{} + + tests := []struct { + name string + columnName string + expected bool + }{ + { + name: "writable column - id", + columnName: "id", + expected: true, + }, + { + name: "writable column - name", + columnName: "name", + expected: true, + }, + { + name: "writable column - email", + columnName: "email", + expected: true, + }, + { + name: "scanonly column should not be writable", + columnName: "computed_col", + expected: false, + }, + { + name: "non-existent column should be writable (dynamic)", + columnName: "nonexistent", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := reflection.IsColumnWritable(model, tt.columnName) + if result != tt.expected { + t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected) + } + }) + } +} + +func TestIsColumnWritable_Gorm(t *testing.T) { + model := &GormTestModel{} + + tests := []struct { + name string + columnName string + expected bool + }{ + { + name: "writable column - id", + columnName: "id", + expected: true, + }, + { + name: "writable column - name", + columnName: "name", + expected: true, + }, + { + name: "writable column - email", + columnName: "email", + expected: true, + }, + { + name: "read-only column with -> should not be writable", + columnName: "readonly_col", + expected: false, + }, + { + name: "column with <-:false should not be writable", + columnName: "nowrite_col", + expected: false, + }, + { + name: "non-existent column should be writable (dynamic)", + columnName: "nonexistent", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := reflection.IsColumnWritable(model, tt.columnName) + if result != tt.expected { + t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected) + } + }) + } +} + +func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) { + // Note: This is a unit test for the validation logic only. + // We can't fully test the bun query without a database connection, + // but we've verified the validation logic in TestIsColumnWritable_Bun + t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun") +} + +func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) { + model := &GormTestModel{} + query := &GormUpdateQuery{ + model: model, + } + + // SetMap should filter out read-only columns + values := map[string]interface{}{ + "name": "John", + "email": "john@example.com", + "readonly_col": "should_be_filtered", + "nowrite_col": "should_also_be_filtered", + } + + query.SetMap(values) + + // Check that the updates map only contains writable columns + if updates, ok := query.updates.(map[string]interface{}); ok { + if _, exists := updates["readonly_col"]; exists { + t.Error("readonly_col should have been filtered out") + } + if _, exists := updates["nowrite_col"]; exists { + t.Error("nowrite_col should have been filtered out") + } + if _, exists := updates["name"]; !exists { + t.Error("name should be in updates") + } + if _, exists := updates["email"]; !exists { + t.Error("email should be in updates") + } + } else { + t.Error("updates should be a map[string]interface{}") + } +} diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index 2363ca4..f28d164 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -236,3 +236,90 @@ func ExtractColumnFromBunTag(tag string) string { } return "" } + +// IsColumnWritable checks if a column can be written to in the database +// For bun: returns false if the field has "scanonly" tag +// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag +func IsColumnWritable(model any, columnName string) bool { + modelType := reflect.TypeOf(model) + + // Unwrap pointers to get to the base struct type + for modelType != nil && modelType.Kind() == reflect.Pointer { + modelType = modelType.Elem() + } + + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + return false + } + + for i := 0; i < modelType.NumField(); i++ { + field := modelType.Field(i) + + // Check if this field matches the column name + fieldColumnName := getColumnNameFromField(field) + if fieldColumnName != columnName { + continue + } + + // Check bun tag for scanonly + bunTag := field.Tag.Get("bun") + if bunTag != "" { + if isBunFieldScanOnly(bunTag) { + return false + } + } + + // Check gorm tag for write restrictions + gormTag := field.Tag.Get("gorm") + if gormTag != "" { + if isGormFieldReadOnly(gormTag) { + return false + } + } + + // Column is writable + return true + } + + // Column not found in model, allow it (might be a dynamic column) + return true +} + +// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only +// Example: "column_name,scanonly" -> true +func isBunFieldScanOnly(tag string) bool { + parts := strings.Split(tag, ",") + for _, part := range parts { + if strings.TrimSpace(part) == "scanonly" { + return true + } + } + return false +} + +// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only +// Examples: +// - "<-:false" -> true (no writes allowed) +// - "->" -> true (read-only, common pattern) +// - "column:name;->" -> true +// - "<-:create" -> false (writes allowed on create) +func isGormFieldReadOnly(tag string) bool { + parts := strings.Split(tag, ";") + for _, part := range parts { + part = strings.TrimSpace(part) + + // Check for read-only marker + if part == "->" { + return true + } + + // Check for write restrictions + if value, found := strings.CutPrefix(part, "<-:"); found { + if value == "false" { + return true + } + } + } + return false +}