reflection-based column validation for UpdateQuery

This commit is contained in:
Hein 2025-11-19 17:41:15 +02:00
parent 14daea3b05
commit 8b7db5b31a
4 changed files with 280 additions and 1 deletions

View File

@ -9,6 +9,7 @@ import (
"github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// BunAdapter adapts Bun to work with our Database interface
@ -353,10 +354,12 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
// BunUpdateQuery implements UpdateQuery for Bun
type BunUpdateQuery struct {
query *bun.UpdateQuery
model interface{}
}
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
b.query = b.query.Model(model)
b.model = model
return b
}
@ -366,12 +369,22 @@ func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
}
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
return b
}
b.query = b.query.Set(column+" = ?", value)
return b
}
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
for column, value := range values {
// Validate column is writable if model is set
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
// Skip scan-only columns
continue
}
b.query = b.query.Set(column+" = ?", value)
}
return b

View File

@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// GormAdapter adapts GORM to work with our Database interface
@ -343,6 +344,12 @@ func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
}
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
// Skip read-only columns
return g
}
if g.updates == nil {
g.updates = make(map[string]interface{})
}
@ -353,7 +360,18 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
}
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
// Filter out read-only columns if model is set
if g.model != nil {
filteredValues := make(map[string]interface{})
for column, value := range values {
if reflection.IsColumnWritable(g.model, column) {
filteredValues[column] = value
}
}
g.updates = filteredValues
} else {
g.updates = values
}
return g
}

View File

@ -0,0 +1,161 @@
package database
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Test models for bun
type BunTestModel struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"email"`
ComputedCol string `bun:"computed_col,scanonly"`
}
// Test models for gorm
type GormTestModel struct {
ID int `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name"`
Email string `gorm:"column:email"`
ReadOnlyCol string `gorm:"column:readonly_col;->"`
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
}
func TestIsColumnWritable_Bun(t *testing.T) {
model := &BunTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "scanonly column should not be writable",
columnName: "computed_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestIsColumnWritable_Gorm(t *testing.T) {
model := &GormTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "read-only column with -> should not be writable",
columnName: "readonly_col",
expected: false,
},
{
name: "column with <-:false should not be writable",
columnName: "nowrite_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
// Note: This is a unit test for the validation logic only.
// We can't fully test the bun query without a database connection,
// but we've verified the validation logic in TestIsColumnWritable_Bun
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
}
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
model := &GormTestModel{}
query := &GormUpdateQuery{
model: model,
}
// SetMap should filter out read-only columns
values := map[string]interface{}{
"name": "John",
"email": "john@example.com",
"readonly_col": "should_be_filtered",
"nowrite_col": "should_also_be_filtered",
}
query.SetMap(values)
// Check that the updates map only contains writable columns
if updates, ok := query.updates.(map[string]interface{}); ok {
if _, exists := updates["readonly_col"]; exists {
t.Error("readonly_col should have been filtered out")
}
if _, exists := updates["nowrite_col"]; exists {
t.Error("nowrite_col should have been filtered out")
}
if _, exists := updates["name"]; !exists {
t.Error("name should be in updates")
}
if _, exists := updates["email"]; !exists {
t.Error("email should be in updates")
}
} else {
t.Error("updates should be a map[string]interface{}")
}
}

View File

@ -236,3 +236,90 @@ func ExtractColumnFromBunTag(tag string) string {
}
return ""
}
// IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
func IsColumnWritable(model any, columnName string) bool {
modelType := reflect.TypeOf(model)
// Unwrap pointers to get to the base struct type
for modelType != nil && modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return false
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check if this field matches the column name
fieldColumnName := getColumnNameFromField(field)
if fieldColumnName != columnName {
continue
}
// Check bun tag for scanonly
bunTag := field.Tag.Get("bun")
if bunTag != "" {
if isBunFieldScanOnly(bunTag) {
return false
}
}
// Check gorm tag for write restrictions
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
if isGormFieldReadOnly(gormTag) {
return false
}
}
// Column is writable
return true
}
// Column not found in model, allow it (might be a dynamic column)
return true
}
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
// Example: "column_name,scanonly" -> true
func isBunFieldScanOnly(tag string) bool {
parts := strings.Split(tag, ",")
for _, part := range parts {
if strings.TrimSpace(part) == "scanonly" {
return true
}
}
return false
}
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
// Examples:
// - "<-:false" -> true (no writes allowed)
// - "->" -> true (read-only, common pattern)
// - "column:name;->" -> true
// - "<-:create" -> false (writes allowed on create)
func isGormFieldReadOnly(tag string) bool {
parts := strings.Split(tag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
// Check for read-only marker
if part == "->" {
return true
}
// Check for write restrictions
if value, found := strings.CutPrefix(part, "<-:"); found {
if value == "false" {
return true
}
}
}
return false
}