From 59bd70946028fc5bc5d30537bdbad9b46799f3f5 Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 21 Nov 2025 08:35:46 +0200 Subject: [PATCH] More reflection function to handle sql columns and get default sqlcolumn lists. --- pkg/reflection/model_utils.go | 121 +++++++++++++++++++++++++ pkg/reflection/model_utils_test.go | 140 +++++++++++++++++++++++++++++ pkg/resolvespec/handler.go | 7 +- pkg/restheadspec/handler.go | 2 +- 4 files changed, 268 insertions(+), 2 deletions(-) diff --git a/pkg/reflection/model_utils.go b/pkg/reflection/model_utils.go index ffea02a..9409e2f 100644 --- a/pkg/reflection/model_utils.go +++ b/pkg/reflection/model_utils.go @@ -323,6 +323,127 @@ func ExtractColumnFromBunTag(tag string) string { return "" } +// GetSQLModelColumns extracts column names that have valid SQL field mappings +// This function only returns columns that: +// 1. Have bun or gorm tags (not just json tags) +// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags) +// 3. Are not scan-only embedded fields +func GetSQLModelColumns(model any) []string { + var columns []string + + modelType := reflect.TypeOf(model) + + // Unwrap pointers, slices, and arrays to get to the base struct type + for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { + modelType = modelType.Elem() + } + + // Validate that we have a struct type + if modelType == nil || modelType.Kind() != reflect.Struct { + return columns + } + + collectSQLColumnsFromType(modelType, &columns, false) + + return columns +} + +// collectSQLColumnsFromType recursively collects SQL column names from a struct type +// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct +func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) { + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + + // Check if this is an embedded struct + if field.Anonymous { + // Unwrap pointer type if necessary + fieldType := field.Type + if fieldType.Kind() == reflect.Pointer { + fieldType = fieldType.Elem() + } + + // Check if the embedded struct itself is scan-only + isScanOnly := scanOnlyEmbedded + bunTag := field.Tag.Get("bun") + if bunTag != "" && isBunFieldScanOnly(bunTag) { + isScanOnly = true + } + + // Recursively process embedded struct + if fieldType.Kind() == reflect.Struct { + collectSQLColumnsFromType(fieldType, columns, isScanOnly) + continue + } + } + + // Skip fields in scan-only embedded structs + if scanOnlyEmbedded { + continue + } + + // Get bun and gorm tags + bunTag := field.Tag.Get("bun") + gormTag := field.Tag.Get("gorm") + + // Skip if neither bun nor gorm tag exists + if bunTag == "" && gormTag == "" { + continue + } + + // Skip if explicitly marked with "-" + if bunTag == "-" || gormTag == "-" { + continue + } + + // Skip if field itself is scan-only (bun) + if bunTag != "" && isBunFieldScanOnly(bunTag) { + continue + } + + // Skip if field itself is read-only (gorm) + if gormTag != "" && isGormFieldReadOnly(gormTag) { + continue + } + + // Skip relation fields (bun) + if bunTag != "" { + // Skip if it's a bun relation (rel:, join:, or m2m:) + if strings.Contains(bunTag, "rel:") || + strings.Contains(bunTag, "join:") || + strings.Contains(bunTag, "m2m:") { + continue + } + } + + // Skip relation fields (gorm) + if gormTag != "" { + // Skip if it has gorm relationship tags + if strings.Contains(gormTag, "foreignKey:") || + strings.Contains(gormTag, "references:") || + strings.Contains(gormTag, "many2many:") || + strings.Contains(gormTag, "constraint:") { + continue + } + } + + // Get column name + columnName := "" + if bunTag != "" { + columnName = ExtractColumnFromBunTag(bunTag) + } + if columnName == "" && gormTag != "" { + columnName = ExtractColumnFromGormTag(gormTag) + } + + // Skip if we couldn't extract a column name + if columnName == "" { + continue + } + + *columns = append(*columns, columnName) + } +} + // IsColumnWritable checks if a column can be written to in the database // For bun: returns false if the field has "scanonly" tag // For gorm: returns false if the field has "<-:false" or "->" (read-only) tag diff --git a/pkg/reflection/model_utils_test.go b/pkg/reflection/model_utils_test.go index 2e581a7..64814e5 100644 --- a/pkg/reflection/model_utils_test.go +++ b/pkg/reflection/model_utils_test.go @@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) { }) } } + +// Test models with relations for GetSQLModelColumns +type User struct { + ID int `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + Email string `bun:"email" json:"email"` + ProfileData string `json:"profile_data"` // No bun/gorm tag + Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"` + Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"` + RowNumber int64 `bun:",scanonly" json:"_rownumber"` +} + +type Post struct { + ID int `gorm:"column:id;primaryKey" json:"id"` + Title string `gorm:"column:title" json:"title"` + UserID int `gorm:"column:user_id;foreignKey" json:"user_id"` + User *User `gorm:"foreignKey:UserID;references:ID" json:"user"` + Tags []Tag `gorm:"many2many:post_tags" json:"tags"` + Content string `json:"content"` // No bun/gorm tag +} + +type Profile struct { + ID int `bun:"id,pk" json:"id"` + Bio string `bun:"bio" json:"bio"` + UserID int `bun:"user_id" json:"user_id"` +} + +type Tag struct { + ID int `gorm:"column:id;primaryKey" json:"id"` + Name string `gorm:"column:name" json:"name"` +} + +// Model with scan-only embedded struct +type EntityWithScanOnlyEmbedded struct { + ID int `bun:"id,pk" json:"id"` + Name string `bun:"name" json:"name"` + AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only +} + +func TestGetSQLModelColumns(t *testing.T) { + tests := []struct { + name string + model any + expected []string + }{ + { + name: "Bun model with relations - excludes relations and non-SQL fields", + model: User{}, + // Should include: id, name, email (has bun tags) + // Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded) + expected: []string{"id", "name", "email"}, + }, + { + name: "GORM model with relations - excludes relations and non-SQL fields", + model: Post{}, + // Should include: id, title, user_id (has gorm tags) + // Should exclude: content (no gorm tag), User/Tags (relations) + expected: []string{"id", "title", "user_id"}, + }, + { + name: "Model with embedded base and scan-only embedded", + model: EntityWithScanOnlyEmbedded{}, + // Should include: id, name from main struct + // Should exclude: all fields from AdhocBuffer (scan-only embedded struct) + expected: []string{"id", "name"}, + }, + { + name: "Model with embedded - includes SQL fields, excludes scan-only", + model: ModelWithEmbedded{}, + // Should include: rid_base, created_at (from BaseModel), name, description (from main) + // Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields) + expected: []string{"rid_base", "created_at", "name", "description"}, + }, + { + name: "GORM model with embedded - includes SQL fields, excludes scan-only", + model: GormModelWithEmbedded{}, + // Should include: rid_base, created_at (from GormBaseModel), name, description (from main) + // Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -) + expected: []string{"rid_base", "created_at", "name", "description"}, + }, + { + name: "Simple Profile model", + model: Profile{}, + // Should include all fields with bun tags + expected: []string{"id", "bio", "user_id"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetSQLModelColumns(tt.model) + if len(result) != len(tt.expected) { + t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v", + len(result), len(tt.expected), result, tt.expected) + return + } + for i, col := range result { + if col != tt.expected[i] { + t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v", + i, col, tt.expected[i], result) + } + } + }) + } +} + +func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) { + // Demonstrate the difference between GetModelColumns and GetSQLModelColumns + user := User{} + + allColumns := GetModelColumns(user) + sqlColumns := GetSQLModelColumns(user) + + t.Logf("GetModelColumns(User): %v", allColumns) + t.Logf("GetSQLModelColumns(User): %v", sqlColumns) + + // GetModelColumns should return more columns (includes fields with only json tags) + if len(allColumns) <= len(sqlColumns) { + t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns") + } + + // GetSQLModelColumns should not include 'profile_data' (no bun tag) + for _, col := range sqlColumns { + if col == "profile_data" { + t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)") + } + } + + // GetModelColumns should include 'profile_data' (has json tag) + hasProfileData := false + for _, col := range allColumns { + if col == "profile_data" { + hasProfileData = true + break + } + } + if !hasProfileData { + t.Errorf("GetModelColumns should include 'profile_data' (has json tag)") + } +} diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 2372733..ccc1cb5 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st query = query.Table(tableName) } + if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) { + logger.Debug("Populating options.Columns with all model columns since computed columns are additions") + options.Columns = reflection.GetSQLModelColumns(model) + } + // Apply column selection if len(options.Columns) > 0 { logger.Debug("Selecting columns: %v", options.Columns) @@ -1145,7 +1150,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre logger.Debug("Applying preload: %s", relationFieldName) query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { if len(preload.OmitColumns) > 0 { - allCols := reflection.GetModelColumns(model) + allCols := reflection.GetSQLModelColumns(model) // Remove omitted columns preload.Columns = []string{} for _, col := range allCols { diff --git a/pkg/restheadspec/handler.go b/pkg/restheadspec/handler.go index 60aeb48..a5e5250 100644 --- a/pkg/restheadspec/handler.go +++ b/pkg/restheadspec/handler.go @@ -264,7 +264,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st // populate it with all model columns first since computed columns are additions if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) { logger.Debug("Populating options.Columns with all model columns since computed columns are additions") - options.Columns = reflection.GetModelColumns(model) + options.Columns = reflection.GetSQLModelColumns(model) } // Apply ComputedQL fields if any