package database import ( "context" "fmt" "strings" "gorm.io/gorm" "github.com/bitechdev/ResolveSpec/pkg/common" ) // GormAdapter adapts GORM to work with our Database interface type GormAdapter struct { db *gorm.DB } // NewGormAdapter creates a new GORM adapter func NewGormAdapter(db *gorm.DB) *GormAdapter { return &GormAdapter{db: db} } func (g *GormAdapter) NewSelect() common.SelectQuery { return &GormSelectQuery{db: g.db} } func (g *GormAdapter) NewInsert() common.InsertQuery { return &GormInsertQuery{db: g.db} } func (g *GormAdapter) NewUpdate() common.UpdateQuery { return &GormUpdateQuery{db: g.db} } func (g *GormAdapter) NewDelete() common.DeleteQuery { return &GormDeleteQuery{db: g.db} } func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { result := g.db.WithContext(ctx).Exec(query, args...) return &GormResult{result: result}, result.Error } func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error } func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) { tx := g.db.WithContext(ctx).Begin() if tx.Error != nil { return nil, tx.Error } return &GormAdapter{db: tx}, nil } func (g *GormAdapter) CommitTx(ctx context.Context) error { return g.db.WithContext(ctx).Commit().Error } func (g *GormAdapter) RollbackTx(ctx context.Context) error { return g.db.WithContext(ctx).Rollback().Error } func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { adapter := &GormAdapter{db: tx} return fn(adapter) }) } // GormSelectQuery implements SelectQuery for GORM type GormSelectQuery struct { db *gorm.DB schema string // Separated schema name tableName string // Just the table name, without schema tableAlias string } func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { g.db = g.db.Model(model) // Try to get table name from model if it implements TableNameProvider if provider, ok := model.(common.TableNameProvider); ok { fullTableName := provider.TableName() // Check if the table name contains schema (e.g., "schema.table") g.schema, g.tableName = parseTableName(fullTableName) } if provider, ok := model.(common.TableAliasProvider); ok { g.tableAlias = provider.TableAlias() } return g } func (g *GormSelectQuery) Table(table string) common.SelectQuery { g.db = g.db.Table(table) // Check if the table name contains schema (e.g., "schema.table") g.schema, g.tableName = parseTableName(table) return g } func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery { g.db = g.db.Select(columns) return g } func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Select(query, args...) return g } func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Where(query, args...) return g } func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { g.db = g.db.Or(query, args...) return g } func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { // Extract optional prefix from args // If the last arg is a string that looks like a table prefix, use it var prefix string sqlArgs := args if len(args) > 0 { if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { // Likely a prefix, not a SQL parameter prefix = lastArg sqlArgs = args[:len(args)-1] } } // If no prefix provided, use the table name as prefix (already separated from schema) if prefix == "" && g.tableName != "" { prefix = g.tableName } // If prefix is provided, add it as an alias in the join // GORM expects: "JOIN table AS alias ON condition" joinClause := query if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { // If query doesn't already have AS, check if it's a simple table name parts := strings.Fields(query) if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { // Simple table name, add prefix: "table AS prefix" joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) if len(parts) > 1 { // Has ON clause: "table ON ..." becomes "table AS prefix ON ..." joinClause += " " + strings.Join(parts[1:], " ") } } } g.db = g.db.Joins(joinClause, sqlArgs...) return g } func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { // Extract optional prefix from args var prefix string sqlArgs := args if len(args) > 0 { if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") { prefix = lastArg sqlArgs = args[:len(args)-1] } } // If no prefix provided, use the table name as prefix (already separated from schema) if prefix == "" && g.tableName != "" { prefix = g.tableName } // Construct LEFT JOIN with prefix joinClause := query if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") { parts := strings.Fields(query) if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") { joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix) if len(parts) > 1 { joinClause += " " + strings.Join(parts[1:], " ") } } } g.db = g.db.Joins("LEFT JOIN "+joinClause, sqlArgs...) return g } func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { g.db = g.db.Preload(relation, conditions...) return g } func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB { if len(apply) == 0 { return db } wrapper := &GormSelectQuery{ db: g.db, } current := common.SelectQuery(wrapper) for _, fn := range apply { if fn != nil { modified := fn(current) current = modified } } if finalBun, ok := current.(*GormSelectQuery); ok { return finalBun.db } return db // fallback }) return g } func (g *GormSelectQuery) Order(order string) common.SelectQuery { g.db = g.db.Order(order) return g } func (g *GormSelectQuery) Limit(n int) common.SelectQuery { g.db = g.db.Limit(n) return g } func (g *GormSelectQuery) Offset(n int) common.SelectQuery { g.db = g.db.Offset(n) return g } func (g *GormSelectQuery) Group(group string) common.SelectQuery { g.db = g.db.Group(group) return g } func (g *GormSelectQuery) Having(having string, args ...interface{}) common.SelectQuery { g.db = g.db.Having(having, args...) return g } func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error { return g.db.WithContext(ctx).Find(dest).Error } func (g *GormSelectQuery) ScanModel(ctx context.Context) error { if g.db.Statement.Model == nil { return fmt.Errorf("ScanModel requires Model() to be set before scanning") } return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error } func (g *GormSelectQuery) Count(ctx context.Context) (int, error) { var count int64 err := g.db.WithContext(ctx).Count(&count).Error return int(count), err } func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) { var count int64 err := g.db.WithContext(ctx).Limit(1).Count(&count).Error return count > 0, err } // GormInsertQuery implements InsertQuery for GORM type GormInsertQuery struct { db *gorm.DB model interface{} values map[string]interface{} } func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery { g.model = model g.db = g.db.Model(model) return g } func (g *GormInsertQuery) Table(table string) common.InsertQuery { g.db = g.db.Table(table) return g } func (g *GormInsertQuery) Value(column string, value interface{}) common.InsertQuery { if g.values == nil { g.values = make(map[string]interface{}) } g.values[column] = value return g } func (g *GormInsertQuery) OnConflict(action string) common.InsertQuery { // GORM handles conflicts differently, this would need specific implementation return g } func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery { // GORM doesn't have explicit RETURNING, but updates the model return g } func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) { var result *gorm.DB switch { case g.model != nil: result = g.db.WithContext(ctx).Create(g.model) case g.values != nil: result = g.db.WithContext(ctx).Create(g.values) default: result = g.db.WithContext(ctx).Create(map[string]interface{}{}) } return &GormResult{result: result}, result.Error } // GormUpdateQuery implements UpdateQuery for GORM type GormUpdateQuery struct { db *gorm.DB model interface{} updates interface{} } func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery { g.model = model g.db = g.db.Model(model) return g } func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { g.db = g.db.Table(table) return g } func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { if g.updates == nil { g.updates = make(map[string]interface{}) } if updates, ok := g.updates.(map[string]interface{}); ok { updates[column] = value } return g } func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { g.updates = values return g } func (g *GormUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery { g.db = g.db.Where(query, args...) return g } func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery { // GORM doesn't have explicit RETURNING return g } func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) { result := g.db.WithContext(ctx).Updates(g.updates) return &GormResult{result: result}, result.Error } // GormDeleteQuery implements DeleteQuery for GORM type GormDeleteQuery struct { db *gorm.DB model interface{} } func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery { g.model = model g.db = g.db.Model(model) return g } func (g *GormDeleteQuery) Table(table string) common.DeleteQuery { g.db = g.db.Table(table) return g } func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery { g.db = g.db.Where(query, args...) return g } func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) { result := g.db.WithContext(ctx).Delete(g.model) return &GormResult{result: result}, result.Error } // GormResult implements Result for GORM type GormResult struct { result *gorm.DB } func (g *GormResult) RowsAffected() int64 { return g.result.RowsAffected } func (g *GormResult) LastInsertId() (int64, error) { // GORM doesn't directly provide last insert ID, would need specific implementation return 0, nil }