mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb54ec5e27 | ||
|
|
7d6a9025f5 | ||
|
|
35089f511f | ||
|
|
66b6a0d835 | ||
|
|
456c165814 | ||
|
|
850d7b546c | ||
|
|
a44ef90d7c |
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -365,6 +366,14 @@ func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
|||||||
|
|
||||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
if b.model == nil {
|
||||||
|
// Try to get table name from table string if model is not set
|
||||||
|
|
||||||
|
model, err := modelregistry.GetModelByName(table)
|
||||||
|
if err == nil {
|
||||||
|
b.model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -379,12 +388,17 @@ func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuer
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(b.model)
|
||||||
for column, value := range values {
|
for column, value := range values {
|
||||||
// Validate column is writable if model is set
|
// Validate column is writable if model is set
|
||||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||||
// Skip scan-only columns
|
// Skip scan-only columns
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if pkName != "" && column == pkName {
|
||||||
|
// Skip primary key updates
|
||||||
|
continue
|
||||||
|
}
|
||||||
b.query = b.query.Set(column+" = ?", value)
|
b.query = b.query.Set(column+" = ?", value)
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -98,6 +99,7 @@ func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
|||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
g.schema, g.tableName = parseTableName(table)
|
g.schema, g.tableName = parseTableName(table)
|
||||||
|
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,6 +342,13 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
|||||||
|
|
||||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
if g.model == nil {
|
||||||
|
// Try to get table name from table string if model is not set
|
||||||
|
model, err := modelregistry.GetModelByName(table)
|
||||||
|
if err == nil {
|
||||||
|
g.model = model
|
||||||
|
}
|
||||||
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,13 +369,20 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||||
|
|
||||||
// Filter out read-only columns if model is set
|
// Filter out read-only columns if model is set
|
||||||
if g.model != nil {
|
if g.model != nil {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(g.model)
|
||||||
filteredValues := make(map[string]interface{})
|
filteredValues := make(map[string]interface{})
|
||||||
for column, value := range values {
|
for column, value := range values {
|
||||||
|
if pkName != "" && column == pkName {
|
||||||
|
// Skip primary key updates
|
||||||
|
continue
|
||||||
|
}
|
||||||
if reflection.IsColumnWritable(g.model, column) {
|
if reflection.IsColumnWritable(g.model, column) {
|
||||||
filteredValues[column] = value
|
filteredValues[column] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
g.updates = filteredValues
|
g.updates = filteredValues
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -111,6 +111,9 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
// Inject parent IDs for foreign key resolution
|
// Inject parent IDs for foreign key resolution
|
||||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||||
|
|
||||||
|
// Get the primary key name for this model
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
// Process based on operation
|
// Process based on operation
|
||||||
switch strings.ToLower(operation) {
|
switch strings.ToLower(operation) {
|
||||||
case "insert", "create":
|
case "insert", "create":
|
||||||
@@ -128,30 +131,30 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "update":
|
case "update":
|
||||||
rows, err := p.processUpdate(ctx, regularData, tableName, data["id"])
|
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("update failed: %w", err)
|
return nil, fmt.Errorf("update failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data["id"]
|
result.ID = data[pkName]
|
||||||
result.AffectedRows = rows
|
result.AffectedRows = rows
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
// Process child relations for update
|
// Process child relations for update
|
||||||
if err := p.processChildRelations(ctx, "update", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "delete":
|
case "delete":
|
||||||
// Process child relations first (for referential integrity)
|
// Process child relations first (for referential integrity)
|
||||||
if err := p.processChildRelations(ctx, "delete", data["id"], relationFields, result.RelationData, modelType); err != nil {
|
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := p.processDelete(ctx, tableName, data["id"])
|
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("delete failed: %w", err)
|
return nil, fmt.Errorf("delete failed: %w", err)
|
||||||
}
|
}
|
||||||
result.ID = data["id"]
|
result.ID = data[pkName]
|
||||||
result.AffectedRows = rows
|
result.AffectedRows = rows
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
@@ -378,8 +381,16 @@ func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
||||||
// It checks if the data contains nested relations or a _request field
|
// It recursively checks if the data contains:
|
||||||
|
// 1. A _request field at any level, OR
|
||||||
|
// 2. Nested relations that themselves contain further nested relations or _request fields
|
||||||
|
// This ensures nested processing is only used when there are deeply nested operations
|
||||||
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
||||||
|
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
|
||||||
|
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
|
||||||
// Check for _request field
|
// Check for _request field
|
||||||
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
||||||
return true
|
return true
|
||||||
@@ -407,19 +418,33 @@ func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, re
|
|||||||
if relInfo != nil {
|
if relInfo != nil {
|
||||||
// Check if the value is actually nested data (object or array)
|
// Check if the value is actually nested data (object or array)
|
||||||
switch v := value.(type) {
|
switch v := value.(type) {
|
||||||
case map[string]interface{}:
|
case map[string]interface{}, []interface{}, []map[string]interface{}:
|
||||||
//logger.Debug("Found nested relation field: %s", key)
|
// If we're already at a nested level (depth > 0) and found a relation,
|
||||||
return ShouldUseNestedProcessor(v, relInfo.RelatedModel, relationshipHelper)
|
// that means we have multi-level nesting, so return true
|
||||||
case []interface{}, []map[string]interface{}:
|
if depth > 0 {
|
||||||
//logger.Debug("Found nested relation field: %s", key)
|
return true
|
||||||
for _, item := range v.([]interface{}) {
|
}
|
||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
// At depth 0, recurse to check if the nested data has further nesting
|
||||||
if ShouldUseNestedProcessor(itemMap, relInfo.RelatedModel, relationshipHelper) {
|
switch typedValue := v.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case []interface{}:
|
||||||
|
for _, item := range typedValue {
|
||||||
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
|
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []map[string]interface{}:
|
||||||
|
for _, itemMap := range typedValue {
|
||||||
|
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
|
|||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Global list of registries (searched in order)
|
||||||
|
var registries = []*DefaultModelRegistry{defaultRegistry}
|
||||||
|
var registriesMutex sync.RWMutex
|
||||||
|
|
||||||
// NewModelRegistry creates a new model registry
|
// NewModelRegistry creates a new model registry
|
||||||
func NewModelRegistry() *DefaultModelRegistry {
|
func NewModelRegistry() *DefaultModelRegistry {
|
||||||
return &DefaultModelRegistry{
|
return &DefaultModelRegistry{
|
||||||
@@ -24,6 +28,34 @@ func NewModelRegistry() *DefaultModelRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||||
|
registriesMutex.Lock()
|
||||||
|
foundAt := -1
|
||||||
|
for idx, r := range registries {
|
||||||
|
if r == defaultRegistry {
|
||||||
|
foundAt = idx
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defaultRegistry = registry
|
||||||
|
if foundAt >= 0 {
|
||||||
|
registries[foundAt] = registry
|
||||||
|
} else {
|
||||||
|
registries = append([]*DefaultModelRegistry{registry}, registries...)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer registriesMutex.Unlock()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddRegistry adds a registry to the global list of registries
|
||||||
|
// Registries are searched in the order they were added
|
||||||
|
func AddRegistry(registry *DefaultModelRegistry) {
|
||||||
|
registriesMutex.Lock()
|
||||||
|
defer registriesMutex.Unlock()
|
||||||
|
registries = append(registries, registry)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
@@ -107,9 +139,19 @@ func RegisterModel(model interface{}, name string) error {
|
|||||||
return defaultRegistry.RegisterModel(name, model)
|
return defaultRegistry.RegisterModel(name, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModelByName retrieves a model from the default global registry by name
|
// GetModelByName retrieves a model by searching through all registries in order
|
||||||
|
// Returns the first match found
|
||||||
func GetModelByName(name string) (interface{}, error) {
|
func GetModelByName(name string) (interface{}, error) {
|
||||||
return defaultRegistry.GetModel(name)
|
registriesMutex.RLock()
|
||||||
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
if model, err := registry.GetModel(name); err == nil {
|
||||||
|
return model, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("model %s not found in any registry", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IterateModels iterates over all models in the default global registry
|
// IterateModels iterates over all models in the default global registry
|
||||||
@@ -122,14 +164,26 @@ func IterateModels(fn func(name string, model interface{})) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetModels returns a list of all models in the default global registry
|
// GetModels returns a list of all models from all registries
|
||||||
|
// Models are collected in registry order, with duplicates included
|
||||||
func GetModels() []interface{} {
|
func GetModels() []interface{} {
|
||||||
defaultRegistry.mutex.RLock()
|
registriesMutex.RLock()
|
||||||
defer defaultRegistry.mutex.RUnlock()
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
models := make([]interface{}, 0, len(defaultRegistry.models))
|
var models []interface{}
|
||||||
for _, model := range defaultRegistry.models {
|
seen := make(map[string]bool)
|
||||||
models = append(models, model)
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
registry.mutex.RLock()
|
||||||
|
for name, model := range registry.models {
|
||||||
|
// Only add the first occurrence of each model name
|
||||||
|
if !seen[name] {
|
||||||
|
models = append(models, model)
|
||||||
|
seen[name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
registry.mutex.RUnlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ type ModelFieldDetail struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
@@ -37,14 +38,43 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
if record.Kind() != reflect.Struct {
|
if record.Kind() != reflect.Struct {
|
||||||
return lst
|
return lst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
collectFieldDetails(record, &lst)
|
||||||
|
|
||||||
|
return lst
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||||
|
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||||
modeltype := record.Type()
|
modeltype := record.Type()
|
||||||
|
|
||||||
for i := 0; i < modeltype.NumField(); i++ {
|
for i := 0; i < modeltype.NumField(); i++ {
|
||||||
fieldtype := modeltype.Field(i)
|
fieldtype := modeltype.Field(i)
|
||||||
|
fieldValue := record.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if fieldtype.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
embeddedValue := fieldValue
|
||||||
|
if fieldValue.Kind() == reflect.Pointer {
|
||||||
|
if fieldValue.IsNil() {
|
||||||
|
// Skip nil embedded pointers
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
embeddedValue = fieldValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if embeddedValue.Kind() == reflect.Struct {
|
||||||
|
collectFieldDetails(embeddedValue, lst)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
gormdetail := fieldtype.Tag.Get("gorm")
|
gormdetail := fieldtype.Tag.Get("gorm")
|
||||||
gormdetail = strings.Trim(gormdetail, " ")
|
gormdetail = strings.Trim(gormdetail, " ")
|
||||||
fielddetail := ModelFieldDetail{}
|
fielddetail := ModelFieldDetail{}
|
||||||
fielddetail.FieldValue = record.Field(i)
|
fielddetail.FieldValue = fieldValue
|
||||||
fielddetail.Name = fieldtype.Name
|
fielddetail.Name = fieldtype.Name
|
||||||
fielddetail.DataType = fieldtype.Type.Name()
|
fielddetail.DataType = fieldtype.Type.Name()
|
||||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||||
@@ -80,10 +110,8 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
|||||||
}
|
}
|
||||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||||
|
|
||||||
lst = append(lst, fielddetail)
|
*lst = append(*lst, fielddetail)
|
||||||
|
|
||||||
}
|
}
|
||||||
return lst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func fnFindKeyVal(src, key string) string {
|
func fnFindKeyVal(src, key string) string {
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ func GetPrimaryKeyName(model any) string {
|
|||||||
|
|
||||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||||
// Returns the value of the primary key field
|
// Returns the value of the primary key field
|
||||||
func GetPrimaryKeyValue(model any) interface{} {
|
func GetPrimaryKeyValue(model any) any {
|
||||||
if model == nil || reflect.TypeOf(model) == nil {
|
if model == nil || reflect.TypeOf(model) == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -61,38 +61,51 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
|
||||||
|
|
||||||
// Try Bun tag first
|
// Try Bun tag first
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if strings.Contains(bunTag, "pk") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to GORM tag
|
// Fall back to GORM tag
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||||
field := typ.Field(i)
|
return pkValue
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if strings.Contains(gormTag, "primaryKey") {
|
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
|
||||||
return fieldValue.Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Last resort: look for field named "ID" or "Id"
|
// Last resort: look for field named "ID" or "Id"
|
||||||
|
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||||
|
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
if strings.ToLower(field.Name) == "id" {
|
fieldValue := val.Field(i)
|
||||||
fieldValue := val.Field(i)
|
|
||||||
if fieldValue.CanInterface() {
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||||
|
return pkValue
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for primary key tag
|
||||||
|
switch ormType {
|
||||||
|
case "bun":
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
case "gorm":
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||||
return fieldValue.Interface()
|
return fieldValue.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,8 +114,35 @@ func GetPrimaryKeyValue(model any) interface{} {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// findFieldByName recursively searches for a field by name in the struct
|
||||||
|
func findFieldByName(val reflect.Value, name string) any {
|
||||||
|
typ := val.Type()
|
||||||
|
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
fieldValue := val.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if result := findFieldByName(fieldValue, name); result != nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if field name matches
|
||||||
|
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() {
|
||||||
|
return fieldValue.Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetModelColumns extracts all column names from a model using reflection
|
// GetModelColumns extracts all column names from a model using reflection
|
||||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||||
|
// This function recursively processes embedded structs to include their fields
|
||||||
func GetModelColumns(model any) []string {
|
func GetModelColumns(model any) []string {
|
||||||
var columns []string
|
var columns []string
|
||||||
|
|
||||||
@@ -118,18 +158,38 @@ func GetModelColumns(model any) []string {
|
|||||||
return columns
|
return columns
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
collectColumnsFromType(modelType, &columns)
|
||||||
field := modelType.Field(i)
|
|
||||||
|
return columns
|
||||||
|
}
|
||||||
|
|
||||||
|
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||||
|
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively process embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
collectColumnsFromType(fieldType, columns)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Get column name using the same logic as primary key extraction
|
// Get column name using the same logic as primary key extraction
|
||||||
columnName := getColumnNameFromField(field)
|
columnName := getColumnNameFromField(field)
|
||||||
|
|
||||||
if columnName != "" {
|
if columnName != "" {
|
||||||
columns = append(columns, columnName)
|
*columns = append(*columns, columnName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return columns
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getColumnNameFromField extracts the column name from a struct field
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
@@ -166,6 +226,7 @@ func getColumnNameFromField(field reflect.StructField) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||||
val := reflect.ValueOf(model)
|
val := reflect.ValueOf(model)
|
||||||
if val.Kind() == reflect.Pointer {
|
if val.Kind() == reflect.Pointer {
|
||||||
@@ -177,9 +238,31 @@ func getPrimaryKeyFromReflection(model any, ormType string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
typ := val.Type()
|
typ := val.Type()
|
||||||
|
return findPrimaryKeyNameFromType(typ, ormType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||||
|
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
field := typ.Field(i)
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||||
|
return pkName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
switch ormType {
|
switch ormType {
|
||||||
case "gorm":
|
case "gorm":
|
||||||
// Check for gorm tag with primaryKey
|
// Check for gorm tag with primaryKey
|
||||||
@@ -231,6 +314,9 @@ func ExtractColumnFromGormTag(tag string) string {
|
|||||||
// Example: ",pk" -> "" (will fall back to json tag)
|
// Example: ",pk" -> "" (will fall back to json tag)
|
||||||
func ExtractColumnFromBunTag(tag string) string {
|
func ExtractColumnFromBunTag(tag string) string {
|
||||||
parts := strings.Split(tag, ",")
|
parts := strings.Split(tag, ",")
|
||||||
|
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
if len(parts) > 0 && parts[0] != "" {
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
return parts[0]
|
return parts[0]
|
||||||
}
|
}
|
||||||
@@ -240,6 +326,7 @@ func ExtractColumnFromBunTag(tag string) string {
|
|||||||
// IsColumnWritable checks if a column can be written to in the database
|
// IsColumnWritable checks if a column can be written to in the database
|
||||||
// For bun: returns false if the field has "scanonly" tag
|
// For bun: returns false if the field has "scanonly" tag
|
||||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||||
|
// This function recursively searches embedded structs
|
||||||
func IsColumnWritable(model any, columnName string) bool {
|
func IsColumnWritable(model any, columnName string) bool {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
@@ -253,8 +340,37 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
found, writable := isColumnWritableInType(modelType, columnName)
|
||||||
field := modelType.Field(i)
|
if found {
|
||||||
|
return writable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column not found in model, allow it (might be a dynamic column)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||||
|
// Returns (found, writable) where found indicates if the column was found
|
||||||
|
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) {
|
||||||
|
for i := 0; i < typ.NumField(); i++ {
|
||||||
|
field := typ.Field(i)
|
||||||
|
|
||||||
|
// Check if this is an embedded struct
|
||||||
|
if field.Anonymous {
|
||||||
|
// Unwrap pointer type if necessary
|
||||||
|
fieldType := field.Type
|
||||||
|
if fieldType.Kind() == reflect.Pointer {
|
||||||
|
fieldType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively search in embedded struct
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||||
|
return true, writable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Check if this field matches the column name
|
// Check if this field matches the column name
|
||||||
fieldColumnName := getColumnNameFromField(field)
|
fieldColumnName := getColumnNameFromField(field)
|
||||||
@@ -262,11 +378,12 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Found the field, now check if it's writable
|
||||||
// Check bun tag for scanonly
|
// Check bun tag for scanonly
|
||||||
bunTag := field.Tag.Get("bun")
|
bunTag := field.Tag.Get("bun")
|
||||||
if bunTag != "" {
|
if bunTag != "" {
|
||||||
if isBunFieldScanOnly(bunTag) {
|
if isBunFieldScanOnly(bunTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,16 +391,16 @@ func IsColumnWritable(model any, columnName string) bool {
|
|||||||
gormTag := field.Tag.Get("gorm")
|
gormTag := field.Tag.Get("gorm")
|
||||||
if gormTag != "" {
|
if gormTag != "" {
|
||||||
if isGormFieldReadOnly(gormTag) {
|
if isGormFieldReadOnly(gormTag) {
|
||||||
return false
|
return true, false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column is writable
|
// Column is writable
|
||||||
return true
|
return true, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Column not found in model, allow it (might be a dynamic column)
|
// Column not found
|
||||||
return true
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||||
|
|||||||
@@ -231,3 +231,246 @@ func TestGetModelColumns(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test models with embedded structs
|
||||||
|
|
||||||
|
type BaseModel struct {
|
||||||
|
ID int `bun:"rid_base,pk" json:"id"`
|
||||||
|
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ModelWithEmbedded struct {
|
||||||
|
BaseModel
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Description string `bun:"description" json:"description"`
|
||||||
|
AdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormBaseModel struct {
|
||||||
|
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||||
|
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormAdhocBuffer struct {
|
||||||
|
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||||
|
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||||
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GormModelWithEmbedded struct {
|
||||||
|
GormBaseModel
|
||||||
|
Name string `gorm:"column:name" json:"name"`
|
||||||
|
Description string `gorm:"column:description" json:"description"`
|
||||||
|
GormAdhocBuffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: "rid_base",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyName(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||||
|
bunModel := ModelWithEmbedded{
|
||||||
|
BaseModel: BaseModel{
|
||||||
|
ID: 123,
|
||||||
|
CreatedAt: "2024-01-01",
|
||||||
|
},
|
||||||
|
Name: "Test",
|
||||||
|
Description: "Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
gormModel := GormModelWithEmbedded{
|
||||||
|
GormBaseModel: GormBaseModel{
|
||||||
|
ID: 456,
|
||||||
|
CreatedAt: "2024-01-02",
|
||||||
|
},
|
||||||
|
Name: "GORM Test",
|
||||||
|
Description: "GORM Test Description",
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base",
|
||||||
|
model: bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded base (pointer)",
|
||||||
|
model: &bunModel,
|
||||||
|
expected: 123,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base",
|
||||||
|
model: gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded base (pointer)",
|
||||||
|
model: &gormModel,
|
||||||
|
expected: 456,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPrimaryKeyValue(tt.model)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model with embedded structs (pointer)",
|
||||||
|
model: &ModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model with embedded structs (pointer)",
|
||||||
|
model: &GormModelWithEmbedded{},
|
||||||
|
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetModelColumns(tt.model)
|
||||||
|
if len(result) != len(tt.expected) {
|
||||||
|
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i, col := range result {
|
||||||
|
if col != tt.expected[i] {
|
||||||
|
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
model any
|
||||||
|
columnName string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in main struct",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - writable column in embedded base",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Bun model - scanonly column _rownumber",
|
||||||
|
model: ModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in main struct",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "name",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - writable column in embedded base",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "rid_base",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "cql1",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "GORM model - readonly column _rownumber",
|
||||||
|
model: GormModelWithEmbedded{},
|
||||||
|
columnName: "_rownumber",
|
||||||
|
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := IsColumnWritable(tt.model, tt.columnName)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -610,6 +610,9 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
dataSlice := h.normalizeToSlice(data)
|
dataSlice := h.normalizeToSlice(data)
|
||||||
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
logger.Debug("Processing %d item(s) for creation", len(dataSlice))
|
||||||
|
|
||||||
|
// Store original data maps for merging later
|
||||||
|
originalDataMaps := make([]map[string]interface{}, 0, len(dataSlice))
|
||||||
|
|
||||||
// Process all items in a transaction
|
// Process all items in a transaction
|
||||||
results := make([]interface{}, 0, len(dataSlice))
|
results := make([]interface{}, 0, len(dataSlice))
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
@@ -630,6 +633,13 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store a copy of the original data map for merging later
|
||||||
|
originalMap := make(map[string]interface{})
|
||||||
|
for k, v := range itemMap {
|
||||||
|
originalMap[k] = v
|
||||||
|
}
|
||||||
|
originalDataMaps = append(originalDataMaps, originalMap)
|
||||||
|
|
||||||
// Extract nested relations if present (but don't process them yet)
|
// Extract nested relations if present (but don't process them yet)
|
||||||
var nestedRelations map[string]interface{}
|
var nestedRelations map[string]interface{}
|
||||||
if h.shouldUseNestedProcessor(itemMap, model) {
|
if h.shouldUseNestedProcessor(itemMap, model) {
|
||||||
@@ -704,14 +714,26 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merge created records with original request data
|
||||||
|
// This preserves extra keys from the request
|
||||||
|
mergedResults := make([]interface{}, 0, len(results))
|
||||||
|
for i, result := range results {
|
||||||
|
if i < len(originalDataMaps) {
|
||||||
|
merged := h.mergeRecordWithRequest(result, originalDataMaps[i])
|
||||||
|
mergedResults = append(mergedResults, merged)
|
||||||
|
} else {
|
||||||
|
mergedResults = append(mergedResults, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Execute AfterCreate hooks
|
// Execute AfterCreate hooks
|
||||||
var responseData interface{}
|
var responseData interface{}
|
||||||
if len(results) == 1 {
|
if len(mergedResults) == 1 {
|
||||||
responseData = results[0]
|
responseData = mergedResults[0]
|
||||||
hookCtx.Result = results[0]
|
hookCtx.Result = mergedResults[0]
|
||||||
} else {
|
} else {
|
||||||
responseData = results
|
responseData = mergedResults
|
||||||
hookCtx.Result = map[string]interface{}{"created": len(results), "data": results}
|
hookCtx.Result = map[string]interface{}{"created": len(mergedResults), "data": mergedResults}
|
||||||
}
|
}
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
|
|
||||||
@@ -721,7 +743,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully created %d record(s)", len(results))
|
logger.Info("Successfully created %d record(s)", len(mergedResults))
|
||||||
h.sendResponseWithOptions(w, responseData, nil, &options)
|
h.sendResponseWithOptions(w, responseData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -790,6 +812,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the primary key name for the model
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// Variable to store the updated record
|
||||||
|
var updatedRecord interface{}
|
||||||
|
|
||||||
// Process nested relations if present
|
// Process nested relations if present
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
// Create temporary nested processor with transaction
|
// Create temporary nested processor with transaction
|
||||||
@@ -808,11 +836,10 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Ensure ID is in the data map for the update
|
// Ensure ID is in the data map for the update
|
||||||
dataMap["id"] = targetID
|
dataMap[pkName] = targetID
|
||||||
|
|
||||||
// Create update query
|
// Create update query
|
||||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
@@ -840,10 +867,18 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store result for hooks
|
// Fetch the updated record to return the new values
|
||||||
hookCtx.Result = map[string]interface{}{
|
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
"updated": result.RowsAffected(),
|
selectQuery := tx.NewSelect().Model(modelValue).Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to fetch updated record: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updatedRecord = modelValue
|
||||||
|
|
||||||
|
// Store result for hooks
|
||||||
|
hookCtx.Result = updatedRecord
|
||||||
|
_ = result // Keep result variable for potential future use
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -853,7 +888,12 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Merge the updated record with the original request data
|
||||||
|
// This preserves extra keys from the request and updates values from the database
|
||||||
|
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||||
|
|
||||||
// Execute AfterUpdate hooks
|
// Execute AfterUpdate hooks
|
||||||
|
hookCtx.Result = mergedData
|
||||||
hookCtx.Error = nil
|
hookCtx.Error = nil
|
||||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||||
logger.Error("AfterUpdate hook failed: %v", err)
|
logger.Error("AfterUpdate hook failed: %v", err)
|
||||||
@@ -862,7 +902,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully updated record with ID: %v", targetID)
|
logger.Info("Successfully updated record with ID: %v", targetID)
|
||||||
h.sendResponseWithOptions(w, hookCtx.Result, nil, &options)
|
h.sendResponseWithOptions(w, mergedData, nil, &options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id string, data interface{}) {
|
||||||
@@ -936,6 +976,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Array of IDs or objects with ID field
|
// Array of IDs or objects with ID field
|
||||||
logger.Info("Batch delete with %d items ([]interface{})", len(v))
|
logger.Info("Batch delete with %d items ([]interface{})", len(v))
|
||||||
deletedCount := 0
|
deletedCount := 0
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
var itemID interface{}
|
var itemID interface{}
|
||||||
@@ -945,7 +986,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
case string:
|
case string:
|
||||||
itemID = v
|
itemID = v
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
itemID = v["id"]
|
itemID = v[pkName]
|
||||||
default:
|
default:
|
||||||
itemID = item
|
itemID = item
|
||||||
}
|
}
|
||||||
@@ -1002,9 +1043,10 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Array of objects with id field
|
// Array of objects with id field
|
||||||
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
|
logger.Info("Batch delete with %d items ([]map[string]interface{})", len(v))
|
||||||
deletedCount := 0
|
deletedCount := 0
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
if itemID, ok := item["id"]; ok && itemID != nil {
|
if itemID, ok := item[pkName]; ok && itemID != nil {
|
||||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||||
|
|
||||||
// Execute hooks for each item
|
// Execute hooks for each item
|
||||||
@@ -1052,7 +1094,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
case map[string]interface{}:
|
case map[string]interface{}:
|
||||||
// Single object with id field
|
// Single object with id field
|
||||||
if itemID, ok := v["id"]; ok && itemID != nil {
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
if itemID, ok := v[pkName]; ok && itemID != nil {
|
||||||
id = fmt.Sprintf("%v", itemID)
|
id = fmt.Sprintf("%v", itemID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1122,6 +1165,39 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
h.sendResponse(w, responseData, nil)
|
h.sendResponse(w, responseData, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeRecordWithRequest merges a database record with the original request data
|
||||||
|
// This preserves extra keys from the request that aren't in the database model
|
||||||
|
// and updates values from the database (e.g., from SQL triggers or defaults)
|
||||||
|
func (h *Handler) mergeRecordWithRequest(dbRecord interface{}, requestData map[string]interface{}) map[string]interface{} {
|
||||||
|
// Convert the database record to a map
|
||||||
|
dbMap := make(map[string]interface{})
|
||||||
|
|
||||||
|
// Marshal and unmarshal to convert struct to map
|
||||||
|
jsonData, err := json.Marshal(dbRecord)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to marshal database record for merging: %v", err)
|
||||||
|
return requestData
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &dbMap); err != nil {
|
||||||
|
logger.Warn("Failed to unmarshal database record for merging: %v", err)
|
||||||
|
return requestData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start with the request data (preserves extra keys)
|
||||||
|
result := make(map[string]interface{})
|
||||||
|
for k, v := range requestData {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update with values from database (overwrites with DB values, including trigger changes)
|
||||||
|
for k, v := range dbMap {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
// normalizeToSlice converts data to a slice. Single items become a 1-item slice.
|
||||||
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
@@ -1658,22 +1734,22 @@ func (h *Handler) cleanJSON(data interface{}) interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
|
func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, message string, err error) {
|
||||||
var details string
|
var errorMsg string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
details = err.Error()
|
errorMsg = err.Error()
|
||||||
|
} else if message != "" {
|
||||||
|
errorMsg = message
|
||||||
|
} else {
|
||||||
|
errorMsg = code
|
||||||
}
|
}
|
||||||
|
|
||||||
response := common.Response{
|
response := map[string]interface{}{
|
||||||
Success: false,
|
"_error": errorMsg,
|
||||||
Error: &common.APIError{
|
"_retval": 1,
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
Details: details,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if jsonErr := w.WriteJSON(response); jsonErr != nil {
|
||||||
logger.Error("Failed to write JSON error response: %v", err)
|
logger.Error("Failed to write JSON error response: %v", jsonErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1904,7 +1980,8 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// shouldUseNestedProcessor determines if we should use nested CUD processing
|
// shouldUseNestedProcessor determines if we should use nested CUD processing
|
||||||
// It checks if the data contains nested relations or a _request field
|
// It recursively checks if the data contains deeply nested relations or _request fields
|
||||||
|
// Simple one-level relations without further nesting don't require the nested processor
|
||||||
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
|
func (h *Handler) shouldUseNestedProcessor(data map[string]interface{}, model interface{}) bool {
|
||||||
return common.ShouldUseNestedProcessor(data, model, h)
|
return common.ShouldUseNestedProcessor(data, model, h)
|
||||||
}
|
}
|
||||||
@@ -1966,12 +2043,40 @@ func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName strin
|
|||||||
// Determine if it's belongsTo or hasMany/hasOne
|
// Determine if it's belongsTo or hasMany/hasOne
|
||||||
if field.Type.Kind() == reflect.Slice {
|
if field.Type.Kind() == reflect.Slice {
|
||||||
info.relationType = "hasMany"
|
info.relationType = "hasMany"
|
||||||
|
// Get the element type for slice
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
} else if field.Type.Kind() == reflect.Ptr || field.Type.Kind() == reflect.Struct {
|
||||||
info.relationType = "belongsTo"
|
info.relationType = "belongsTo"
|
||||||
|
elemType := field.Type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if strings.Contains(gormTag, "many2many") {
|
} else if strings.Contains(gormTag, "many2many") {
|
||||||
info.relationType = "many2many"
|
info.relationType = "many2many"
|
||||||
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
info.joinTable = h.extractTagValue(gormTag, "many2many")
|
||||||
|
// Get the element type for many2many (always slice)
|
||||||
|
if field.Type.Kind() == reflect.Slice {
|
||||||
|
elemType := field.Type.Elem()
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
}
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
info.relatedModel = reflect.New(elemType).Elem().Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Field has no GORM relationship tags, so it's not a relation
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ func TestShouldUseNestedProcessor(t *testing.T) {
|
|||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Data with nested posts",
|
name: "Data with simple nested posts (no further nesting)",
|
||||||
data: map[string]interface{}{
|
data: map[string]interface{}{
|
||||||
"name": "John",
|
"name": "John",
|
||||||
"posts": []map[string]interface{}{
|
"posts": []map[string]interface{}{
|
||||||
@@ -140,7 +140,23 @@ func TestShouldUseNestedProcessor(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
model: TestUser{},
|
model: TestUser{},
|
||||||
expected: true,
|
expected: false, // Simple one-level nesting doesn't require nested processor
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Data with deeply nested relations",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"title": "Post 1",
|
||||||
|
"comments": []map[string]interface{}{
|
||||||
|
{"content": "Comment 1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true, // Multi-level nesting requires nested processor
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Data without nested relations",
|
name: "Data without nested relations",
|
||||||
@@ -159,6 +175,20 @@ func TestShouldUseNestedProcessor(t *testing.T) {
|
|||||||
model: TestUser{},
|
model: TestUser{},
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Nested data with _request field",
|
||||||
|
data: map[string]interface{}{
|
||||||
|
"name": "John",
|
||||||
|
"posts": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"_request": "insert",
|
||||||
|
"title": "Post 1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
model: TestUser{},
|
||||||
|
expected: true, // _request at nested level requires nested processor
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
type TestModel struct {
|
type TestModel struct {
|
||||||
ID int64 `json:"id" bun:"id,pk"`
|
ID int64 `json:"id" bun:"id,pk"`
|
||||||
Name string `json:"name" bun:"name"`
|
Name string `json:"name" bun:"name"`
|
||||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:"-"`
|
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetRowNumbersOnRecords(t *testing.T) {
|
func TestSetRowNumbersOnRecords(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user