From d122c7af42c4346374aa518fa67cb269285f7a75 Mon Sep 17 00:00:00 2001 From: Hein Date: Fri, 7 Nov 2025 08:26:50 +0200 Subject: [PATCH] Updated how model registry works --- pkg/common/adapters/database/bun.go | 85 ++++++++++++++++++++++++++- pkg/common/adapters/database/gorm.go | 86 +++++++++++++++++++++++++++- pkg/modelregistry/model_registry.go | 19 +++++- pkg/resolvespec/handler.go | 12 +++- pkg/testmodels/business.go | 24 ++++---- 5 files changed, 203 insertions(+), 23 deletions(-) diff --git a/pkg/common/adapters/database/bun.go b/pkg/common/adapters/database/bun.go index a26b675..1fedffd 100644 --- a/pkg/common/adapters/database/bun.go +++ b/pkg/common/adapters/database/bun.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/uptrace/bun" @@ -76,16 +77,25 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa // BunSelectQuery implements SelectQuery for Bun type BunSelectQuery struct { - query *bun.SelectQuery + query *bun.SelectQuery + tableName string + tableAlias string } func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { b.query = b.query.Model(model) + + // Try to get table name from model if it implements TableNameProvider + if provider, ok := model.(common.TableNameProvider); ok { + b.tableName = provider.TableName() + } + return b } func (b *BunSelectQuery) Table(table string) common.SelectQuery { b.query = b.query.Table(table) + b.tableName = table return b } @@ -105,12 +115,81 @@ func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.Selec } func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { - b.query = b.query.Join(query, args...) + // Extract optional prefix from args + // If the last arg is a string that looks like a table prefix, use it + var prefix string + sqlArgs := args + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { + // Likely a prefix, not a SQL parameter + prefix = lastArg + sqlArgs = args[:len(args)-1] + } + } + + // If no prefix provided, use the table name as prefix + if prefix == "" && b.tableName != "" { + prefix = b.tableName + // Extract just the table name if it has schema + if idx := strings.LastIndex(prefix, "."); idx != -1 { + prefix = prefix[idx+1:] + } + } + + // If prefix is provided, add it as an alias in the join + // Bun expects: "JOIN table AS alias ON condition" + joinClause := query + if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { + // If query doesn't already have AS, check if it's a simple table name + parts := strings.Fields(query) + if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { + // Simple table name, add prefix: "table AS prefix" + joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) + if len(parts) > 1 { + // Has ON clause: "table ON ..." becomes "table AS prefix ON ..." + joinClause += " " + strings.Join(parts[1:], " ") + } + } + } + + b.query = b.query.Join(joinClause, sqlArgs...) return b } func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { - b.query = b.query.Join("LEFT JOIN " + query, args...) + // Extract optional prefix from args + var prefix string + sqlArgs := args + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { + prefix = lastArg + sqlArgs = args[:len(args)-1] + } + } + + // If no prefix provided, use the table name as prefix + if prefix == "" && b.tableName != "" { + prefix = b.tableName + if idx := strings.LastIndex(prefix, "."); idx != -1 { + prefix = prefix[idx+1:] + } + } + + // Construct LEFT JOIN with prefix + joinClause := query + if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { + parts := strings.Fields(query) + if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { + joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) + if len(parts) > 1 { + joinClause += " " + strings.Join(parts[1:], " ") + } + } + } + + b.query = b.query.Join("LEFT JOIN " + joinClause, sqlArgs...) return b } diff --git a/pkg/common/adapters/database/gorm.go b/pkg/common/adapters/database/gorm.go index 26de9e9..4748cf4 100644 --- a/pkg/common/adapters/database/gorm.go +++ b/pkg/common/adapters/database/gorm.go @@ -2,6 +2,8 @@ package database import ( "context" + "fmt" + "strings" "github.com/Warky-Devs/ResolveSpec/pkg/common" "gorm.io/gorm" @@ -67,16 +69,25 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab // GormSelectQuery implements SelectQuery for GORM type GormSelectQuery struct { - db *gorm.DB + db *gorm.DB + tableName string + tableAlias string } func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { g.db = g.db.Model(model) + + // Try to get table name from model if it implements TableNameProvider + if provider, ok := model.(common.TableNameProvider); ok { + g.tableName = provider.TableName() + } + return g } func (g *GormSelectQuery) Table(table string) common.SelectQuery { g.db = g.db.Table(table) + g.tableName = table return g } @@ -96,12 +107,81 @@ func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.Sele } func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { - g.db = g.db.Joins(query, args...) + // Extract optional prefix from args + // If the last arg is a string that looks like a table prefix, use it + var prefix string + sqlArgs := args + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { + // Likely a prefix, not a SQL parameter + prefix = lastArg + sqlArgs = args[:len(args)-1] + } + } + + // If no prefix provided, use the table name as prefix + if prefix == "" && g.tableName != "" { + prefix = g.tableName + // Extract just the table name if it has schema + if idx := strings.LastIndex(prefix, "."); idx != -1 { + prefix = prefix[idx+1:] + } + } + + // If prefix is provided, add it as an alias in the join + // GORM expects: "JOIN table AS alias ON condition" + joinClause := query + if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { + // If query doesn't already have AS, check if it's a simple table name + parts := strings.Fields(query) + if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { + // Simple table name, add prefix: "table AS prefix" + joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) + if len(parts) > 1 { + // Has ON clause: "table ON ..." becomes "table AS prefix ON ..." + joinClause += " " + strings.Join(parts[1:], " ") + } + } + } + + g.db = g.db.Joins(joinClause, sqlArgs...) return g } func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { - g.db = g.db.Joins("LEFT JOIN "+query, args...) + // Extract optional prefix from args + var prefix string + sqlArgs := args + + if len(args) > 0 { + if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { + prefix = lastArg + sqlArgs = args[:len(args)-1] + } + } + + // If no prefix provided, use the table name as prefix + if prefix == "" && g.tableName != "" { + prefix = g.tableName + if idx := strings.LastIndex(prefix, "."); idx != -1 { + prefix = prefix[idx+1:] + } + } + + // Construct LEFT JOIN with prefix + joinClause := query + if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { + parts := strings.Fields(query) + if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { + joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) + if len(parts) > 1 { + joinClause += " " + strings.Join(parts[1:], " ") + } + } + } + + g.db = g.db.Joins("LEFT JOIN "+joinClause, sqlArgs...) return g } diff --git a/pkg/modelregistry/model_registry.go b/pkg/modelregistry/model_registry.go index d30eb25..930da2e 100644 --- a/pkg/modelregistry/model_registry.go +++ b/pkg/modelregistry/model_registry.go @@ -2,6 +2,7 @@ package modelregistry import ( "fmt" + "reflect" "sync" ) @@ -26,11 +27,25 @@ func NewModelRegistry() *DefaultModelRegistry { func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error { r.mutex.Lock() defer r.mutex.Unlock() - + if _, exists := r.models[name]; exists { return fmt.Errorf("model %s already registered", name) } - + + // Validate that model is a non-pointer struct + modelType := reflect.TypeOf(model) + if modelType == nil { + return fmt.Errorf("model cannot be nil") + } + + if modelType.Kind() == reflect.Ptr { + return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind()) + } + + if modelType.Kind() != reflect.Struct { + return fmt.Errorf("model must be a struct, got %s", modelType.Kind()) + } + r.models[name] = model return nil } diff --git a/pkg/resolvespec/handler.go b/pkg/resolvespec/handler.go index 8a84630..9a274c9 100644 --- a/pkg/resolvespec/handler.go +++ b/pkg/resolvespec/handler.go @@ -93,7 +93,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem return } - query := h.db.NewSelect().Model(model) + // Model is now a non-pointer struct, create a pointer instance for ORM + modelType := reflect.TypeOf(model) + modelPtr := reflect.New(modelType).Interface() + + query := h.db.NewSelect().Model(modelPtr) // Get table name tableName := h.getTableName(schema, entity, model) @@ -149,7 +153,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem var result interface{} if id != "" { logger.Debug("Querying single record with ID: %s", id) - singleResult := model + // Create a pointer to the struct type for scanning + singleResult := reflect.New(modelType).Interface() query = query.Where("id = ?", id) if err := query.Scan(ctx, singleResult); err != nil { logger.Error("Error querying record: %v", err) @@ -159,7 +164,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem result = singleResult } else { logger.Debug("Querying multiple records") - sliceType := reflect.SliceOf(reflect.TypeOf(model)) + // Create a slice of the struct type (not pointers) + sliceType := reflect.SliceOf(modelType) results := reflect.New(sliceType).Interface() if err := query.Scan(ctx, results); err != nil { diff --git a/pkg/testmodels/business.go b/pkg/testmodels/business.go index d68d8d1..9322592 100644 --- a/pkg/testmodels/business.go +++ b/pkg/testmodels/business.go @@ -140,22 +140,22 @@ func (Comment) TableName() string { // RegisterTestModels registers all test models with the provided registry func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) { - registry.RegisterModel("departments", &Department{}) - registry.RegisterModel("employees", &Employee{}) - registry.RegisterModel("projects", &Project{}) - registry.RegisterModel("project_tasks", &ProjectTask{}) - registry.RegisterModel("documents", &Document{}) - registry.RegisterModel("comments", &Comment{}) + registry.RegisterModel("departments", Department{}) + registry.RegisterModel("employees", Employee{}) + registry.RegisterModel("projects", Project{}) + registry.RegisterModel("project_tasks", ProjectTask{}) + registry.RegisterModel("documents", Document{}) + registry.RegisterModel("comments", Comment{}) } // GetTestModels returns a list of all test model instances func GetTestModels() []interface{} { return []interface{}{ - &Department{}, - &Employee{}, - &Project{}, - &ProjectTask{}, - &Document{}, - &Comment{}, + Department{}, + Employee{}, + Project{}, + ProjectTask{}, + Document{}, + Comment{}, } }