Updated how model registry works

This commit is contained in:
Hein 2025-11-07 08:26:50 +02:00
parent 8e06736701
commit d122c7af42
5 changed files with 203 additions and 23 deletions

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -76,16 +77,25 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
// BunSelectQuery implements SelectQuery for Bun // BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct { type BunSelectQuery struct {
query *bun.SelectQuery query *bun.SelectQuery
tableName string
tableAlias string
} }
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
b.query = b.query.Model(model) b.query = b.query.Model(model)
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
b.tableName = provider.TableName()
}
return b return b
} }
func (b *BunSelectQuery) Table(table string) common.SelectQuery { func (b *BunSelectQuery) Table(table string) common.SelectQuery {
b.query = b.query.Table(table) b.query = b.query.Table(table)
b.tableName = table
return b return b
} }
@ -105,12 +115,81 @@ func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.Selec
} }
func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.Join(query, args...) // Extract optional prefix from args
// If the last arg is a string that looks like a table prefix, use it
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
// Likely a prefix, not a SQL parameter
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix
if prefix == "" && b.tableName != "" {
prefix = b.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
// Bun expects: "JOIN table AS alias ON condition"
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
// If query doesn't already have AS, check if it's a simple table name
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
// Simple table name, add prefix: "table AS prefix"
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
// Has ON clause: "table ON ..." becomes "table AS prefix ON ..."
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
b.query = b.query.Join(joinClause, sqlArgs...)
return b return b
} }
func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.Join("LEFT JOIN " + query, args...) // Extract optional prefix from args
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix
if prefix == "" && b.tableName != "" {
prefix = b.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
b.query = b.query.Join("LEFT JOIN " + joinClause, sqlArgs...)
return b return b
} }

View File

@ -2,6 +2,8 @@ package database
import ( import (
"context" "context"
"fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common" "github.com/Warky-Devs/ResolveSpec/pkg/common"
"gorm.io/gorm" "gorm.io/gorm"
@ -67,16 +69,25 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
// GormSelectQuery implements SelectQuery for GORM // GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct { type GormSelectQuery struct {
db *gorm.DB db *gorm.DB
tableName string
tableAlias string
} }
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
g.db = g.db.Model(model) g.db = g.db.Model(model)
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
g.tableName = provider.TableName()
}
return g return g
} }
func (g *GormSelectQuery) Table(table string) common.SelectQuery { func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
g.tableName = table
return g return g
} }
@ -96,12 +107,81 @@ func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.Sele
} }
func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Joins(query, args...) // Extract optional prefix from args
// If the last arg is a string that looks like a table prefix, use it
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
// Likely a prefix, not a SQL parameter
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix
if prefix == "" && g.tableName != "" {
prefix = g.tableName
// Extract just the table name if it has schema
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// If prefix is provided, add it as an alias in the join
// GORM expects: "JOIN table AS alias ON condition"
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
// If query doesn't already have AS, check if it's a simple table name
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
// Simple table name, add prefix: "table AS prefix"
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
// Has ON clause: "table ON ..." becomes "table AS prefix ON ..."
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
g.db = g.db.Joins(joinClause, sqlArgs...)
return g return g
} }
func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Joins("LEFT JOIN "+query, args...) // Extract optional prefix from args
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix
if prefix == "" && g.tableName != "" {
prefix = g.tableName
if idx := strings.LastIndex(prefix, "."); idx != -1 {
prefix = prefix[idx+1:]
}
}
// Construct LEFT JOIN with prefix
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
g.db = g.db.Joins("LEFT JOIN "+joinClause, sqlArgs...)
return g return g
} }

View File

@ -2,6 +2,7 @@ package modelregistry
import ( import (
"fmt" "fmt"
"reflect"
"sync" "sync"
) )
@ -31,6 +32,20 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
return fmt.Errorf("model %s already registered", name) return fmt.Errorf("model %s already registered", name)
} }
// Validate that model is a non-pointer struct
modelType := reflect.TypeOf(model)
if modelType == nil {
return fmt.Errorf("model cannot be nil")
}
if modelType.Kind() == reflect.Ptr {
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
}
if modelType.Kind() != reflect.Struct {
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
}
r.models[name] = model r.models[name] = model
return nil return nil
} }

View File

@ -93,7 +93,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
return return
} }
query := h.db.NewSelect().Model(model) // Model is now a non-pointer struct, create a pointer instance for ORM
modelType := reflect.TypeOf(model)
modelPtr := reflect.New(modelType).Interface()
query := h.db.NewSelect().Model(modelPtr)
// Get table name // Get table name
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
@ -149,7 +153,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
var result interface{} var result interface{}
if id != "" { if id != "" {
logger.Debug("Querying single record with ID: %s", id) logger.Debug("Querying single record with ID: %s", id)
singleResult := model // Create a pointer to the struct type for scanning
singleResult := reflect.New(modelType).Interface()
query = query.Where("id = ?", id) query = query.Where("id = ?", id)
if err := query.Scan(ctx, singleResult); err != nil { if err := query.Scan(ctx, singleResult); err != nil {
logger.Error("Error querying record: %v", err) logger.Error("Error querying record: %v", err)
@ -159,7 +164,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
result = singleResult result = singleResult
} else { } else {
logger.Debug("Querying multiple records") logger.Debug("Querying multiple records")
sliceType := reflect.SliceOf(reflect.TypeOf(model)) // Create a slice of the struct type (not pointers)
sliceType := reflect.SliceOf(modelType)
results := reflect.New(sliceType).Interface() results := reflect.New(sliceType).Interface()
if err := query.Scan(ctx, results); err != nil { if err := query.Scan(ctx, results); err != nil {

View File

@ -140,22 +140,22 @@ func (Comment) TableName() string {
// RegisterTestModels registers all test models with the provided registry // RegisterTestModels registers all test models with the provided registry
func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) { func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) {
registry.RegisterModel("departments", &Department{}) registry.RegisterModel("departments", Department{})
registry.RegisterModel("employees", &Employee{}) registry.RegisterModel("employees", Employee{})
registry.RegisterModel("projects", &Project{}) registry.RegisterModel("projects", Project{})
registry.RegisterModel("project_tasks", &ProjectTask{}) registry.RegisterModel("project_tasks", ProjectTask{})
registry.RegisterModel("documents", &Document{}) registry.RegisterModel("documents", Document{})
registry.RegisterModel("comments", &Comment{}) registry.RegisterModel("comments", Comment{})
} }
// GetTestModels returns a list of all test model instances // GetTestModels returns a list of all test model instances
func GetTestModels() []interface{} { func GetTestModels() []interface{} {
return []interface{}{ return []interface{}{
&Department{}, Department{},
&Employee{}, Employee{},
&Project{}, Project{},
&ProjectTask{}, ProjectTask{},
&Document{}, Document{},
&Comment{}, Comment{},
} }
} }