mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-06 14:26:22 +00:00
More reflection function to handle sql columns and get default sqlcolumn lists.
This commit is contained in:
parent
05962035b6
commit
59bd709460
@ -323,6 +323,127 @@ func ExtractColumnFromBunTag(tag string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||||
|
// This function only returns columns that:
|
||||||
|
// 1. Have bun or gorm tags (not just json tags)
|
||||||
|
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||||
|
// 3. Are not scan-only embedded fields
|
||||||
|
func GetSQLModelColumns(model any) []string {
|
||||||
|
var columns []string
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that we have a struct type
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
collectSQLColumnsFromType(modelType, &columns, false)
|
||||||
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||||
|
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||||
|
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the embedded struct itself is scan-only
|
||||||
|
isScanOnly := scanOnlyEmbedded
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
isScanOnly = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip fields in scan-only embedded structs
|
||||||
|
if scanOnlyEmbedded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get bun and gorm tags
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
|
||||||
|
// Skip if neither bun nor gorm tag exists
|
||||||
|
if bunTag == "" && gormTag == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if explicitly marked with "-"
|
||||||
|
if bunTag == "-" || gormTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is scan-only (bun)
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if field itself is read-only (gorm)
|
||||||
|
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (bun)
|
||||||
|
if bunTag != "" {
|
||||||
|
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||||
|
if strings.Contains(bunTag, "rel:") ||
|
||||||
|
strings.Contains(bunTag, "join:") ||
|
||||||
|
strings.Contains(bunTag, "m2m:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (gorm)
|
||||||
|
if gormTag != "" {
|
||||||
|
// Skip if it has gorm relationship tags
|
||||||
|
if strings.Contains(gormTag, "foreignKey:") ||
|
||||||
|
strings.Contains(gormTag, "references:") ||
|
||||||
|
strings.Contains(gormTag, "many2many:") ||
|
||||||
|
strings.Contains(gormTag, "constraint:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get column name
|
||||||
|
columnName := ""
|
||||||
|
if bunTag != "" {
|
||||||
|
columnName = ExtractColumnFromBunTag(bunTag)
|
||||||
|
}
|
||||||
|
if columnName == "" && gormTag != "" {
|
||||||
|
columnName = ExtractColumnFromGormTag(gormTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip if we couldn't extract a column name
|
||||||
|
if columnName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
*columns = append(*columns, columnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// IsColumnWritable checks if a column can be written to in the database
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
// For bun: returns false if the field has "scanonly" tag
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
|||||||
@ -474,3 +474,143 @@ func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test models with relations for GetSQLModelColumns
|
||||||
|
type User struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Email string `bun:"email" json:"email"`
|
||||||
|
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||||
|
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||||
|
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||||
|
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Title string `gorm:"column:title" json:"title"`
|
||||||
|
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||||
|
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||||
|
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||||
|
Content string `json:"content"` // No bun/gorm tag
|
||||||
|
}
|
||||||
|
|
||||||
|
type Profile struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Bio string `bun:"bio" json:"bio"`
|
||||||
|
UserID int `bun:"user_id" json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Tag struct {
|
||||||
|
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Model with scan-only embedded struct
|
||||||
|
type EntityWithScanOnlyEmbedded struct {
|
||||||
|
ID int `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumns(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: User{},
|
||||||
|
// Should include: id, name, email (has bun tags)
|
||||||
|
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||||
|
expected: []string{"id", "name", "email"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||||
|
model: Post{},
|
||||||
|
// Should include: id, title, user_id (has gorm tags)
|
||||||
|
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||||
|
expected: []string{"id", "title", "user_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded base and scan-only embedded",
|
||||||
|
model: EntityWithScanOnlyEmbedded{},
|
||||||
|
// Should include: id, name from main struct
|
||||||
|
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||||
|
expected: []string{"id", "name"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||||
|
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple Profile model",
|
||||||
|
model: Profile{},
|
||||||
|
// Should include all fields with bun tags
|
||||||
|
expected: []string{"id", "bio", "user_id"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetSQLModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||||
|
len(result), len(tt.expected), result, tt.expected)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||||
|
i, col, tt.expected[i], result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||||
|
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||||
|
user := User{}
|
||||||
|
|
||||||
|
allColumns := GetModelColumns(user)
|
||||||
|
sqlColumns := GetSQLModelColumns(user)
|
||||||
|
|
||||||
|
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||||
|
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||||
|
|
||||||
|
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||||
|
if len(allColumns) <= len(sqlColumns) {
|
||||||
|
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||||
|
for _, col := range sqlColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelColumns should include 'profile_data' (has json tag)
|
||||||
|
hasProfileData := false
|
||||||
|
for _, col := range allColumns {
|
||||||
|
if col == "profile_data" {
|
||||||
|
hasProfileData = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasProfileData {
|
||||||
|
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -191,6 +191,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
|
||||||
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
|
||||||
// Apply column selection
|
// Apply column selection
|
||||||
if len(options.Columns) > 0 {
|
if len(options.Columns) > 0 {
|
||||||
logger.Debug("Selecting columns: %v", options.Columns)
|
logger.Debug("Selecting columns: %v", options.Columns)
|
||||||
@ -1145,7 +1150,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
|||||||
logger.Debug("Applying preload: %s", relationFieldName)
|
logger.Debug("Applying preload: %s", relationFieldName)
|
||||||
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
|
||||||
if len(preload.OmitColumns) > 0 {
|
if len(preload.OmitColumns) > 0 {
|
||||||
allCols := reflection.GetModelColumns(model)
|
allCols := reflection.GetSQLModelColumns(model)
|
||||||
// Remove omitted columns
|
// Remove omitted columns
|
||||||
preload.Columns = []string{}
|
preload.Columns = []string{}
|
||||||
for _, col := range allCols {
|
for _, col := range allCols {
|
||||||
|
|||||||
@ -264,7 +264,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// populate it with all model columns first since computed columns are additions
|
// populate it with all model columns first since computed columns are additions
|
||||||
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
|
||||||
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
|
||||||
options.Columns = reflection.GetModelColumns(model)
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user