mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-11-13 18:03:53 +00:00
Updated how model registry works
This commit is contained in:
parent
8e06736701
commit
d122c7af42
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@ -77,15 +78,24 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
// BunSelectQuery implements SelectQuery for Bun
|
// BunSelectQuery implements SelectQuery for Bun
|
||||||
type BunSelectQuery struct {
|
type BunSelectQuery struct {
|
||||||
query *bun.SelectQuery
|
query *bun.SelectQuery
|
||||||
|
tableName string
|
||||||
|
tableAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
|
|
||||||
|
// Try to get table name from model if it implements TableNameProvider
|
||||||
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
b.tableName = provider.TableName()
|
||||||
|
}
|
||||||
|
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
b.tableName = table
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,12 +115,81 @@ func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.Selec
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
|
||||||
b.query = b.query.Join(query, args...)
|
// Extract optional prefix from args
|
||||||
|
// If the last arg is a string that looks like a table prefix, use it
|
||||||
|
var prefix string
|
||||||
|
sqlArgs := args
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
|
||||||
|
// Likely a prefix, not a SQL parameter
|
||||||
|
prefix = lastArg
|
||||||
|
sqlArgs = args[:len(args)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prefix provided, use the table name as prefix
|
||||||
|
if prefix == "" && b.tableName != "" {
|
||||||
|
prefix = b.tableName
|
||||||
|
// Extract just the table name if it has schema
|
||||||
|
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||||
|
prefix = prefix[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If prefix is provided, add it as an alias in the join
|
||||||
|
// Bun expects: "JOIN table AS alias ON condition"
|
||||||
|
joinClause := query
|
||||||
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
|
// If query doesn't already have AS, check if it's a simple table name
|
||||||
|
parts := strings.Fields(query)
|
||||||
|
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
||||||
|
// Simple table name, add prefix: "table AS prefix"
|
||||||
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
// Has ON clause: "table ON ..." becomes "table AS prefix ON ..."
|
||||||
|
joinClause += " " + strings.Join(parts[1:], " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.query = b.query.Join(joinClause, sqlArgs...)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
|
||||||
b.query = b.query.Join("LEFT JOIN " + query, args...)
|
// Extract optional prefix from args
|
||||||
|
var prefix string
|
||||||
|
sqlArgs := args
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
|
||||||
|
prefix = lastArg
|
||||||
|
sqlArgs = args[:len(args)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prefix provided, use the table name as prefix
|
||||||
|
if prefix == "" && b.tableName != "" {
|
||||||
|
prefix = b.tableName
|
||||||
|
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||||
|
prefix = prefix[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct LEFT JOIN with prefix
|
||||||
|
joinClause := query
|
||||||
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
|
parts := strings.Fields(query)
|
||||||
|
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
||||||
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
joinClause += " " + strings.Join(parts[1:], " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.query = b.query.Join("LEFT JOIN " + joinClause, sqlArgs...)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,8 @@ package database
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
@ -68,15 +70,24 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
tableName string
|
||||||
|
tableAlias string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Model(model)
|
g.db = g.db.Model(model)
|
||||||
|
|
||||||
|
// Try to get table name from model if it implements TableNameProvider
|
||||||
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
g.tableName = provider.TableName()
|
||||||
|
}
|
||||||
|
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
g.tableName = table
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -96,12 +107,81 @@ func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.Sele
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Joins(query, args...)
|
// Extract optional prefix from args
|
||||||
|
// If the last arg is a string that looks like a table prefix, use it
|
||||||
|
var prefix string
|
||||||
|
sqlArgs := args
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
|
||||||
|
// Likely a prefix, not a SQL parameter
|
||||||
|
prefix = lastArg
|
||||||
|
sqlArgs = args[:len(args)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prefix provided, use the table name as prefix
|
||||||
|
if prefix == "" && g.tableName != "" {
|
||||||
|
prefix = g.tableName
|
||||||
|
// Extract just the table name if it has schema
|
||||||
|
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||||
|
prefix = prefix[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If prefix is provided, add it as an alias in the join
|
||||||
|
// GORM expects: "JOIN table AS alias ON condition"
|
||||||
|
joinClause := query
|
||||||
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
|
// If query doesn't already have AS, check if it's a simple table name
|
||||||
|
parts := strings.Fields(query)
|
||||||
|
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
||||||
|
// Simple table name, add prefix: "table AS prefix"
|
||||||
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
// Has ON clause: "table ON ..." becomes "table AS prefix ON ..."
|
||||||
|
joinClause += " " + strings.Join(parts[1:], " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.db = g.db.Joins(joinClause, sqlArgs...)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Joins("LEFT JOIN "+query, args...)
|
// Extract optional prefix from args
|
||||||
|
var prefix string
|
||||||
|
sqlArgs := args
|
||||||
|
|
||||||
|
if len(args) > 0 {
|
||||||
|
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
|
||||||
|
prefix = lastArg
|
||||||
|
sqlArgs = args[:len(args)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no prefix provided, use the table name as prefix
|
||||||
|
if prefix == "" && g.tableName != "" {
|
||||||
|
prefix = g.tableName
|
||||||
|
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||||
|
prefix = prefix[idx+1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct LEFT JOIN with prefix
|
||||||
|
joinClause := query
|
||||||
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
|
parts := strings.Fields(query)
|
||||||
|
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
||||||
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
|
if len(parts) > 1 {
|
||||||
|
joinClause += " " + strings.Join(parts[1:], " ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.db = g.db.Joins("LEFT JOIN "+joinClause, sqlArgs...)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package modelregistry
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,6 +32,20 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
return fmt.Errorf("model %s already registered", name)
|
return fmt.Errorf("model %s already registered", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate that model is a non-pointer struct
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return fmt.Errorf("model cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
return fmt.Errorf("model must be a non-pointer struct, got pointer to %s", modelType.Elem().Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("model must be a struct, got %s", modelType.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
r.models[name] = model
|
r.models[name] = model
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -93,7 +93,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
query := h.db.NewSelect().Model(model)
|
// Model is now a non-pointer struct, create a pointer instance for ORM
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
modelPtr := reflect.New(modelType).Interface()
|
||||||
|
|
||||||
|
query := h.db.NewSelect().Model(modelPtr)
|
||||||
|
|
||||||
// Get table name
|
// Get table name
|
||||||
tableName := h.getTableName(schema, entity, model)
|
tableName := h.getTableName(schema, entity, model)
|
||||||
@ -149,7 +153,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
|
|||||||
var result interface{}
|
var result interface{}
|
||||||
if id != "" {
|
if id != "" {
|
||||||
logger.Debug("Querying single record with ID: %s", id)
|
logger.Debug("Querying single record with ID: %s", id)
|
||||||
singleResult := model
|
// Create a pointer to the struct type for scanning
|
||||||
|
singleResult := reflect.New(modelType).Interface()
|
||||||
query = query.Where("id = ?", id)
|
query = query.Where("id = ?", id)
|
||||||
if err := query.Scan(ctx, singleResult); err != nil {
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
logger.Error("Error querying record: %v", err)
|
logger.Error("Error querying record: %v", err)
|
||||||
@ -159,7 +164,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, schem
|
|||||||
result = singleResult
|
result = singleResult
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Querying multiple records")
|
logger.Debug("Querying multiple records")
|
||||||
sliceType := reflect.SliceOf(reflect.TypeOf(model))
|
// Create a slice of the struct type (not pointers)
|
||||||
|
sliceType := reflect.SliceOf(modelType)
|
||||||
results := reflect.New(sliceType).Interface()
|
results := reflect.New(sliceType).Interface()
|
||||||
|
|
||||||
if err := query.Scan(ctx, results); err != nil {
|
if err := query.Scan(ctx, results); err != nil {
|
||||||
|
|||||||
@ -140,22 +140,22 @@ func (Comment) TableName() string {
|
|||||||
|
|
||||||
// RegisterTestModels registers all test models with the provided registry
|
// RegisterTestModels registers all test models with the provided registry
|
||||||
func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) {
|
func RegisterTestModels(registry *modelregistry.DefaultModelRegistry) {
|
||||||
registry.RegisterModel("departments", &Department{})
|
registry.RegisterModel("departments", Department{})
|
||||||
registry.RegisterModel("employees", &Employee{})
|
registry.RegisterModel("employees", Employee{})
|
||||||
registry.RegisterModel("projects", &Project{})
|
registry.RegisterModel("projects", Project{})
|
||||||
registry.RegisterModel("project_tasks", &ProjectTask{})
|
registry.RegisterModel("project_tasks", ProjectTask{})
|
||||||
registry.RegisterModel("documents", &Document{})
|
registry.RegisterModel("documents", Document{})
|
||||||
registry.RegisterModel("comments", &Comment{})
|
registry.RegisterModel("comments", Comment{})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetTestModels returns a list of all test model instances
|
// GetTestModels returns a list of all test model instances
|
||||||
func GetTestModels() []interface{} {
|
func GetTestModels() []interface{} {
|
||||||
return []interface{}{
|
return []interface{}{
|
||||||
&Department{},
|
Department{},
|
||||||
&Employee{},
|
Employee{},
|
||||||
&Project{},
|
Project{},
|
||||||
&ProjectTask{},
|
ProjectTask{},
|
||||||
&Document{},
|
Document{},
|
||||||
&Comment{},
|
Comment{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user