ResolveSpec/pkg/common/adapters/database/pgsql.go
Hein 4cc943b9d3
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Added row PgSQLAdapter
2025-12-10 15:28:09 +02:00

1356 lines
34 KiB
Go

package database
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// PgSQLAdapter implements the project's Database abstraction directly on top
// of database/sql, giving PostgreSQL access without an ORM layer in between.
type PgSQLAdapter struct {
	// db is the handle every non-transactional statement is issued on.
	db *sql.DB
}
// NewPgSQLAdapter wraps an existing *sql.DB in a PgSQLAdapter.
// The handle is stored as given; nothing is opened or validated here.
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
	adapter := new(PgSQLAdapter)
	adapter.db = db
	return adapter
}
// EnableQueryDebug only emits an informational log line; the adapter keeps no
// debug flag of its own because every statement is already logged through
// logger.Debug.
func (p *PgSQLAdapter) EnableQueryDebug() {
	const notice = "PgSQL query debug mode - logging enabled via logger"
	logger.Info(notice)
}
// NewSelect returns a SELECT builder bound to the adapter's database handle.
// The builder starts out selecting "*" with an empty argument list.
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
	q := &PgSQLSelectQuery{db: p.db}
	q.columns = []string{"*"}
	q.args = make([]interface{}, 0)
	return q
}
// NewInsert returns an INSERT builder bound to the adapter's database handle,
// with an empty column/value map ready for Value calls.
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
	q := &PgSQLInsertQuery{db: p.db}
	q.values = make(map[string]interface{})
	return q
}
// NewUpdate returns an UPDATE builder bound to the adapter's database handle,
// with empty SET, WHERE and argument collections.
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
	q := &PgSQLUpdateQuery{db: p.db}
	q.sets = make(map[string]interface{})
	q.args = make([]interface{}, 0)
	q.whereClauses = make([]string, 0)
	return q
}
// NewDelete returns a DELETE builder bound to the adapter's database handle,
// with empty WHERE and argument collections.
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
	q := &PgSQLDeleteQuery{db: p.db}
	q.args = make([]interface{}, 0)
	q.whereClauses = make([]string, 0)
	return q
}
// Exec runs a statement that returns no rows and wraps the driver result in a
// PgSQLResult. The statement and its arguments are logged at debug level; a
// failure is logged and returned unchanged. A panic raised while executing is
// turned into the returned error by logger.HandlePanic.
func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = logger.HandlePanic("PgSQLAdapter.Exec", recovered)
		}
	}()

	logger.Debug("PgSQL Exec: %s [args: %v]", query, args)

	sqlResult, execErr := p.db.ExecContext(ctx, query, args...)
	if execErr != nil {
		logger.Error("PgSQL Exec failed: %v", execErr)
		return nil, execErr
	}
	return &PgSQLResult{result: sqlResult}, nil
}
// Query runs a row-returning statement and decodes the rows into dest via
// scanRows. The statement is logged at debug level; a query failure is logged
// and returned unchanged. The result set is always closed before returning,
// and a panic is turned into the returned error by logger.HandlePanic.
func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = logger.HandlePanic("PgSQLAdapter.Query", recovered)
		}
	}()

	logger.Debug("PgSQL Query: %s [args: %v]", query, args)

	rows, queryErr := p.db.QueryContext(ctx, query, args...)
	if queryErr != nil {
		logger.Error("PgSQL Query failed: %v", queryErr)
		return queryErr
	}
	defer rows.Close()

	err = scanRows(rows, dest)
	return err
}
// BeginTx opens a transaction with the driver's default options and returns a
// PgSQLTxAdapter bound to it.
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
	tx, err := p.db.BeginTx(ctx, nil)
	if err == nil {
		return &PgSQLTxAdapter{tx: tx}, nil
	}
	return nil, err
}
// CommitTx always fails on the non-transactional adapter: committing is only
// meaningful on the adapter returned by BeginTx.
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
	err := fmt.Errorf("CommitTx should be called on transaction adapter")
	return err
}
// RollbackTx always fails on the non-transactional adapter: rolling back is
// only meaningful on the adapter returned by BeginTx.
func (p *PgSQLAdapter) RollbackTx(ctx context.Context) error {
	err := fmt.Errorf("RollbackTx should be called on transaction adapter")
	return err
}
// RunInTransaction opens a transaction, hands fn an adapter bound to it, and
// finishes the transaction according to the outcome: commit when fn returns
// nil, rollback when fn returns an error or panics. The commit error, if any,
// becomes the returned error. A panic from fn is rolled back, re-raised, and
// then converted into the returned error by the outer recover handler.
func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
	// Registered first, therefore runs last: turns any panic (including the
	// one re-raised by the finalizer below) into an error value.
	defer func() {
		if r := recover(); r != nil {
			err = logger.HandlePanic("PgSQLAdapter.RunInTransaction", r)
		}
	}()

	tx, beginErr := p.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return beginErr
	}

	// Finalizer: decides between rollback and commit once fn has finished.
	defer func() {
		panicked := recover()
		switch {
		case panicked != nil:
			_ = tx.Rollback()
			panic(panicked)
		case err != nil:
			_ = tx.Rollback()
		default:
			err = tx.Commit()
		}
	}()

	err = fn(&PgSQLTxAdapter{tx: tx})
	return err
}
// preloadConfig describes one relation that should be loaded together with the
// main query result.
type preloadConfig struct {
	// relation is the struct field name of the relation on the parent model.
	relation string
	// conditions holds the extra arguments handed to Preload.
	// NOTE(review): nothing in this file reads them yet.
	conditions []interface{}
	// applyFuncs are caller-supplied hooks applied to the relation's query.
	applyFuncs []func(common.SelectQuery) common.SelectQuery
	// useJoin selects LEFT JOIN loading (true) or a follow-up query (false).
	useJoin bool
}
// relationMetadata is what the adapter derives about a relation from the
// parent struct's field and its tags.
type relationMetadata struct {
	// fieldName is the relation's field name on the parent struct.
	fieldName string
	// relationType is the kind reported by reflection.GetRelationType.
	relationType reflection.RelationType
	// foreignKey is the left side of the tag's "join:<left>=<right>" option.
	foreignKey string
	// targetTable is the related model's TableName(), when it provides one.
	targetTable string
	// targetKey is the right side of the tag's "join:<left>=<right>" option.
	targetKey string
}
// PgSQLSelectQuery is the SELECT builder of the PostgreSQL adapter. It collects
// clause fragments and positional arguments, and renders them in buildSQL.
type PgSQLSelectQuery struct {
	// Execution target: tx is used when set, otherwise db.
	db *sql.DB
	tx *sql.Tx

	// Source model and table.
	model      interface{}
	tableName  string
	tableAlias string

	// Clause fragments, kept in the order they were added.
	columns       []string
	columnExprs   []string
	whereClauses  []string // joined with AND
	orClauses     []string // joined with OR, then ANDed with whereClauses
	joins         []string
	orderBy       []string
	groupBy       []string
	havingClauses []string

	// limit and offset are only rendered when greater than zero.
	limit  int
	offset int

	// args are the positional arguments; paramCounter is the number of the
	// last $N placeholder handed out by replacePlaceholders.
	args         []interface{}
	paramCounter int

	// preloads are the relations requested via Preload/PreloadRelation/JoinRelation.
	preloads []preloadConfig
}
// Model records the model for this query and, when the model provides them,
// adopts its table name and table alias.
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
	p.model = model
	if named, hasName := model.(common.TableNameProvider); hasName {
		p.tableName = named.TableName()
	}
	if aliased, hasAlias := model.(common.TableAliasProvider); hasAlias {
		p.tableAlias = aliased.TableAlias()
	}
	return p
}
// Table sets the table to select from, replacing any name taken from Model.
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
	p.tableName = table

	return p
}
// Column adds columns to the select list. The initial "*" default is dropped
// the first time explicit columns are added.
func (p *PgSQLSelectQuery) Column(columns ...string) common.SelectQuery {
	onlyWildcard := len(p.columns) == 1 && p.columns[0] == "*"
	if onlyWildcard {
		p.columns = make([]string, 0)
	}
	p.columns = append(p.columns, columns...)
	return p
}
// ColumnExpr adds a raw SQL expression to the select list.
//
// "?" placeholders in the expression are numbered like those of Where and
// Having, so the shared parameter counter stays in step with p.args.
// Without that, an expression with arguments would leave its "?" in the SQL
// and shift every later placeholder onto the wrong argument. An expression
// without arguments is stored unchanged.
func (p *PgSQLSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
	query = p.replacePlaceholders(query, len(args))
	p.columnExprs = append(p.columnExprs, query)
	p.args = append(p.args, args...)
	return p
}
// Where adds a condition that is ANDed with the other Where conditions.
// One "?" per argument is rewritten to the next $N placeholder.
func (p *PgSQLSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.whereClauses = append(p.whereClauses, numbered)
	p.args = append(p.args, args...)
	return p
}
// WhereOr adds a condition to the OR group; buildSQL joins that group with OR
// and ANDs it with the regular Where conditions. One "?" per argument is
// rewritten to the next $N placeholder.
func (p *PgSQLSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.orClauses = append(p.orClauses, numbered)
	p.args = append(p.args, args...)
	return p
}
// Join adds "JOIN <query>" to the statement. One "?" per argument is rewritten
// to the next $N placeholder.
func (p *PgSQLSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.joins = append(p.joins, "JOIN "+numbered)
	p.args = append(p.args, args...)
	return p
}
// LeftJoin adds "LEFT JOIN <query>" to the statement. One "?" per argument is
// rewritten to the next $N placeholder.
func (p *PgSQLSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.joins = append(p.joins, "LEFT JOIN "+numbered)
	p.args = append(p.args, args...)
	return p
}
// Preload requests that relation be loaded with a follow-up query after the
// main rows have been scanned. The conditions are stored on the preload entry.
func (p *PgSQLSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
	entry := preloadConfig{
		relation:   relation,
		conditions: conditions,
	}
	// Plain Preload never uses the JOIN strategy.
	entry.useJoin = false
	p.preloads = append(p.preloads, entry)
	return p
}
// PreloadRelation requests that relation be loaded, choosing the strategy from
// the relation type: when a model is set, reflection.GetRelationType decides
// whether a JOIN is used; without a model the follow-up query strategy is used.
func (p *PgSQLSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
	viaJoin := false
	if p.model != nil {
		relType := reflection.GetRelationType(p.model, relation)
		viaJoin = relType.ShouldUseJoin()
		logger.Debug("PreloadRelation '%s' detected as: %s (useJoin: %v)", relation, relType, viaJoin)
	}

	entry := preloadConfig{relation: relation, applyFuncs: apply, useJoin: viaJoin}
	p.preloads = append(p.preloads, entry)
	return p
}
// JoinRelation requests that relation be loaded with the JOIN strategy,
// regardless of the relation type.
func (p *PgSQLSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
	logger.Debug("JoinRelation '%s' - forcing JOIN strategy", relation)

	entry := preloadConfig{relation: relation, applyFuncs: apply}
	entry.useJoin = true
	p.preloads = append(p.preloads, entry)
	return p
}
// Order appends one ORDER BY term; terms are rendered comma-separated in the
// order they were added.
func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
	p.orderBy = append(p.orderBy, order)

	return p
}
// Limit sets the LIMIT value. buildSQL only renders it when it is greater
// than zero.
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
	p.limit = n

	return p
}
// Offset sets the OFFSET value. buildSQL only renders it when it is greater
// than zero.
func (p *PgSQLSelectQuery) Offset(n int) common.SelectQuery {
	p.offset = n

	return p
}
// Group appends one GROUP BY term; terms are rendered comma-separated in the
// order they were added.
func (p *PgSQLSelectQuery) Group(group string) common.SelectQuery {
	p.groupBy = append(p.groupBy, group)

	return p
}
// Having adds a HAVING condition; conditions are joined with AND. One "?" per
// argument is rewritten to the next $N placeholder.
func (p *PgSQLSelectQuery) Having(having string, args ...interface{}) common.SelectQuery {
	numbered := p.replacePlaceholders(having, len(args))
	p.havingClauses = append(p.havingClauses, numbered)
	p.args = append(p.args, args...)
	return p
}
// buildSQL renders the collected fragments into one SELECT statement, in the
// order SELECT, FROM, JOIN, WHERE, GROUP BY, HAVING, ORDER BY, LIMIT, OFFSET.
// Empty parts are skipped; LIMIT and OFFSET are skipped unless positive.
func (p *PgSQLSelectQuery) buildSQL() string {
	var out strings.Builder

	// Select list: plain columns first, then raw expressions; "*" if neither.
	out.WriteString("SELECT ")
	if total := len(p.columns) + len(p.columnExprs); total == 0 {
		out.WriteString("*")
	} else {
		selectList := make([]string, 0, total)
		selectList = append(selectList, p.columns...)
		selectList = append(selectList, p.columnExprs...)
		out.WriteString(strings.Join(selectList, ", "))
	}

	if p.tableName != "" {
		out.WriteString(" FROM " + p.tableName)
		if p.tableAlias != "" {
			out.WriteString(" AS " + p.tableAlias)
		}
	}

	if len(p.joins) > 0 {
		out.WriteString(" " + strings.Join(p.joins, " "))
	}

	// WHERE: the AND group and the OR group are each parenthesized, then ANDed.
	var groups []string
	if len(p.whereClauses) > 0 {
		groups = append(groups, "("+strings.Join(p.whereClauses, " AND ")+")")
	}
	if len(p.orClauses) > 0 {
		groups = append(groups, "("+strings.Join(p.orClauses, " OR ")+")")
	}
	if len(groups) > 0 {
		out.WriteString(" WHERE " + strings.Join(groups, " AND "))
	}

	if len(p.groupBy) > 0 {
		out.WriteString(" GROUP BY " + strings.Join(p.groupBy, ", "))
	}
	if len(p.havingClauses) > 0 {
		out.WriteString(" HAVING " + strings.Join(p.havingClauses, " AND "))
	}
	if len(p.orderBy) > 0 {
		out.WriteString(" ORDER BY " + strings.Join(p.orderBy, ", "))
	}
	if p.limit > 0 {
		fmt.Fprintf(&out, " LIMIT %d", p.limit)
	}
	if p.offset > 0 {
		fmt.Fprintf(&out, " OFFSET %d", p.offset)
	}

	return out.String()
}
// replacePlaceholders rewrites the first argCount "?" markers in query to
// consecutive $N placeholders, continuing from the query-wide counter. The
// counter advances once per argument even if the text has fewer "?" markers.
func (p *PgSQLSelectQuery) replacePlaceholders(query string, argCount int) string {
	for remaining := argCount; remaining > 0; remaining-- {
		p.paramCounter++
		query = strings.Replace(query, "?", fmt.Sprintf("$%d", p.paramCounter), 1)
	}
	return query
}
// Scan executes the SELECT and loads the result into dest.
//
// Steps: JOIN-strategy preloads are added as LEFT JOINs, the statement is
// built and run (on the transaction when one is attached, otherwise on the
// database handle), the rows are decoded by scanRows, and the remaining
// preloads are loaded with follow-up queries. A query failure is logged and
// returned; a panic anywhere in the sequence is converted into the returned
// error by logger.HandlePanic.
func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
		}
	}()

	// Apply preloads that use JOINs
	p.applyJoinPreloads()

	query := p.buildSQL()
	logger.Debug("PgSQL SELECT: %s [args: %v]", query, p.args)

	var rows *sql.Rows
	if p.tx != nil {
		rows, err = p.tx.QueryContext(ctx, query, p.args...)
	} else {
		rows, err = p.db.QueryContext(ctx, query, p.args...)
	}
	if err != nil {
		logger.Error("PgSQL SELECT failed: %v", err)
		return err
	}
	// Safety net for early returns and panics; Rows.Close is idempotent.
	defer rows.Close()

	if err = scanRows(rows, dest); err != nil {
		return err
	}

	// Release the result set before the follow-up queries run. A single-struct
	// scan reads only the first row, so the cursor would otherwise stay open
	// and, inside a transaction, keep its connection busy while the preload
	// queries try to use it.
	if closeErr := rows.Close(); closeErr != nil {
		logger.Warn("PgSQL SELECT: closing rows failed: %v", closeErr)
	}

	// Apply preloads that use separate queries
	return p.applySubqueryPreloads(ctx, dest)
}
// ScanModel runs the query and scans into the value given to Model. It fails
// when no model has been set.
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
	if p.model != nil {
		return p.Scan(ctx, p.model)
	}
	return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
// Count runs "SELECT COUNT(*)" over the query's table, joins and WHERE
// conditions and returns the number of matching rows.
//
// The table alias is rendered exactly as buildSQL renders it, so conditions
// and joins that refer to the alias resolve in the COUNT statement as well.
// A scan failure is logged and returned; a panic is converted into the
// returned error (with count reset to 0) by logger.HandlePanic.
//
// NOTE(review): GROUP BY, HAVING, LIMIT and OFFSET are not part of the COUNT
// statement, yet p.args is passed in full — arguments that were added for
// column expressions or HAVING are sent too. Confirm callers do not combine
// those with Count.
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
			count = 0
		}
	}()

	// Build a COUNT query
	var sb strings.Builder
	sb.WriteString("SELECT COUNT(*) FROM ")
	sb.WriteString(p.tableName)
	if p.tableAlias != "" {
		sb.WriteString(" AS ")
		sb.WriteString(p.tableAlias)
	}

	if len(p.joins) > 0 {
		sb.WriteString(" ")
		sb.WriteString(strings.Join(p.joins, " "))
	}

	if len(p.whereClauses) > 0 || len(p.orClauses) > 0 {
		sb.WriteString(" WHERE ")
		conditions := make([]string, 0, 2)
		if len(p.whereClauses) > 0 {
			conditions = append(conditions, "("+strings.Join(p.whereClauses, " AND ")+")")
		}
		if len(p.orClauses) > 0 {
			conditions = append(conditions, "("+strings.Join(p.orClauses, " OR ")+")")
		}
		sb.WriteString(strings.Join(conditions, " AND "))
	}

	query := sb.String()
	logger.Debug("PgSQL COUNT: %s [args: %v]", query, p.args)

	var row *sql.Row
	if p.tx != nil {
		row = p.tx.QueryRowContext(ctx, query, p.args...)
	} else {
		row = p.db.QueryRowContext(ctx, query, p.args...)
	}

	err = row.Scan(&count)
	if err != nil {
		logger.Error("PgSQL COUNT failed: %v", err)
	}
	return count, err
}
// Exists reports whether Count finds at least one row. Count's error is
// returned alongside the result; a panic is converted into the returned error
// (with exists reset to false) by logger.HandlePanic.
func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = logger.HandlePanic("PgSQLSelectQuery.Exists", recovered)
			exists = false
		}
	}()

	var total int
	total, err = p.Count(ctx)
	exists = total > 0
	return exists, err
}
// PgSQLInsertQuery is the INSERT builder of the PostgreSQL adapter.
type PgSQLInsertQuery struct {
	// Execution target: tx is used when set, otherwise db.
	db *sql.DB
	tx *sql.Tx

	// tableName is the table rows are inserted into.
	tableName string
	// values maps column name to the value inserted for it.
	values map[string]interface{}
	// returning lists the columns of the RETURNING clause, if any.
	returning []string
}
// Model adopts the model's table name when the model provides one.
//
// This is a simplified implementation: column values are NOT extracted from
// the model; they have to be supplied through Value.
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
	named, hasName := model.(common.TableNameProvider)
	if hasName {
		p.tableName = named.TableName()
	}
	return p
}
// Table sets the table to insert into, replacing any name taken from Model.
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
	p.tableName = table

	return p
}
// Value sets the value to insert for column; a later call for the same column
// replaces the earlier value.
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
	p.values[column] = value

	return p
}
// OnConflict is not implemented yet: the action is ignored, a warning is
// logged, and the query is returned unchanged.
func (p *PgSQLInsertQuery) OnConflict(action string) common.InsertQuery {
	const notice = "OnConflict not yet implemented in PgSQL adapter"
	logger.Warn(notice)
	return p
}
// Returning sets the columns of the RETURNING clause, replacing any earlier
// selection.
func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
	p.returning = columns

	return p
}
// Exec builds and runs the INSERT statement from the values set via Value.
//
// It fails when no values were set. Columns are emitted in map iteration
// order, each paired with its own $N placeholder. The statement runs on the
// transaction when one is attached, otherwise on the database handle. Because
// it is run with ExecContext, rows produced by a RETURNING clause are not read
// back. A failure is logged and returned; a panic is converted into the
// returned error by logger.HandlePanic.
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = logger.HandlePanic("PgSQLInsertQuery.Exec", recovered)
		}
	}()

	count := len(p.values)
	if count == 0 {
		return nil, fmt.Errorf("no values to insert")
	}

	columns := make([]string, 0, count)
	placeholders := make([]string, 0, count)
	args := make([]interface{}, 0, count)
	for col, val := range p.values {
		columns = append(columns, col)
		args = append(args, val)
		// The placeholder number is the argument's 1-based position.
		placeholders = append(placeholders, fmt.Sprintf("$%d", len(args)))
	}

	query := "INSERT INTO " + p.tableName +
		" (" + strings.Join(columns, ", ") + ")" +
		" VALUES (" + strings.Join(placeholders, ", ") + ")"
	if len(p.returning) > 0 {
		query += " RETURNING " + strings.Join(p.returning, ", ")
	}

	logger.Debug("PgSQL INSERT: %s [args: %v]", query, args)

	var result sql.Result
	if p.tx == nil {
		result, err = p.db.ExecContext(ctx, query, args...)
	} else {
		result, err = p.tx.ExecContext(ctx, query, args...)
	}
	if err != nil {
		logger.Error("PgSQL INSERT failed: %v", err)
		return nil, err
	}
	return &PgSQLResult{result: result}, nil
}
// PgSQLUpdateQuery is the UPDATE builder of the PostgreSQL adapter.
type PgSQLUpdateQuery struct {
	// Execution target: tx is used when set, otherwise db.
	db *sql.DB
	tx *sql.Tx

	// tableName is the table being updated.
	tableName string
	// model, when set, is used to filter out non-writable columns.
	model interface{}
	// sets maps column name to its new value.
	sets map[string]interface{}
	// whereClauses are joined with AND; their $N placeholders are numbered
	// from 1 by Where and index into args.
	whereClauses []string
	args         []interface{}
	// paramCounter is the number of the last placeholder handed out by Where.
	paramCounter int
	// returning lists the columns of the RETURNING clause, if any.
	returning []string
}
// Model records the model for this update and adopts its table name when the
// model provides one.
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
	p.model = model
	named, hasName := model.(common.TableNameProvider)
	if hasName {
		p.tableName = named.TableName()
	}
	return p
}
// Table sets the table to update. When no model has been set yet, the model
// registry is asked for a model registered under the table name; a lookup
// failure is ignored and simply leaves the query without a model.
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
	p.tableName = table
	if p.model != nil {
		return p
	}
	if registered, err := modelregistry.GetModelByName(table); err == nil {
		p.model = registered
	}
	return p
}
// Set assigns a new value to column. When a model is known and
// reflection.IsColumnWritable rejects the column, the call is silently ignored.
func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
	writable := p.model == nil || reflection.IsColumnWritable(p.model, column)
	if writable {
		p.sets[column] = value
	}
	return p
}
// SetMap assigns several column values at once. When a model is known, its
// primary key column and every column rejected by reflection.IsColumnWritable
// are skipped.
func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
	hasModel := p.model != nil

	var pkName string
	if hasModel {
		pkName = reflection.GetPrimaryKeyName(p.model)
	}

	for column, value := range values {
		switch {
		case pkName != "" && column == pkName:
			// The primary key is never updated.
		case hasModel && !reflection.IsColumnWritable(p.model, column):
			// Not writable according to the model.
		default:
			p.sets[column] = value
		}
	}
	return p
}
// Where adds a condition; conditions are joined with AND. One "?" per argument
// is rewritten to the next $N placeholder.
func (p *PgSQLUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.whereClauses = append(p.whereClauses, numbered)
	p.args = append(p.args, args...)
	return p
}
// Returning sets the columns of the RETURNING clause, replacing any earlier
// selection.
func (p *PgSQLUpdateQuery) Returning(columns ...string) common.UpdateQuery {
	p.returning = columns

	return p
}
// replacePlaceholders rewrites the first argCount "?" markers in query to
// consecutive $N placeholders, continuing from the query-wide counter. The
// counter advances once per argument even if the text has fewer "?" markers.
func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) string {
	for remaining := argCount; remaining > 0; remaining-- {
		p.paramCounter++
		query = strings.Replace(query, "?", fmt.Sprintf("$%d", p.paramCounter), 1)
	}
	return query
}
// Exec builds and runs the UPDATE statement.
//
// It fails when no SET values were collected. SET values take placeholders
// $1..$k (k = number of SET columns, emitted in map iteration order). The
// WHERE clauses were numbered from $1 by Where, so every WHERE placeholder is
// shifted up by k and the WHERE arguments follow the SET arguments.
//
// The shift is computed on copies: p.whereClauses is left untouched, so
// calling Exec again produces the same statement. The statement runs on the
// transaction when one is attached, otherwise on the database handle. A
// failure is logged and returned; a panic is converted into the returned
// error by logger.HandlePanic.
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
		}
	}()

	if len(p.sets) == 0 {
		return nil, fmt.Errorf("no values to update")
	}

	// All arguments: SET values first, then WHERE values.
	setClauses := make([]string, 0, len(p.sets))
	allArgs := make([]interface{}, 0, len(p.sets)+len(p.args))
	for col, val := range p.sets {
		allArgs = append(allArgs, val)
		setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, len(allArgs)))
	}
	offset := len(allArgs)
	allArgs = append(allArgs, p.args...)

	query := fmt.Sprintf("UPDATE %s SET %s",
		p.tableName,
		strings.Join(setClauses, ", "))

	if len(p.whereClauses) > 0 {
		shifted := make([]string, len(p.whereClauses))
		for i, clause := range p.whereClauses {
			shifted[i] = shiftPgSQLPlaceholders(clause, offset)
		}
		query += " WHERE " + strings.Join(shifted, " AND ")
	}

	if len(p.returning) > 0 {
		query += " RETURNING " + strings.Join(p.returning, ", ")
	}

	logger.Debug("PgSQL UPDATE: %s [args: %v]", query, allArgs)

	var result sql.Result
	if p.tx != nil {
		result, err = p.tx.ExecContext(ctx, query, allArgs...)
	} else {
		result, err = p.db.ExecContext(ctx, query, allArgs...)
	}
	if err != nil {
		logger.Error("PgSQL UPDATE failed: %v", err)
		return nil, err
	}
	return &PgSQLResult{result: result}, nil
}

// shiftPgSQLPlaceholders returns clause with every $N placeholder renumbered
// to $(N+offset). The text is rewritten in a single left-to-right pass over
// whole numbers, so "$1" is never confused with the prefix of "$10" and a
// freshly written placeholder is never rewritten again. A "$" that is not
// followed by a digit is copied unchanged.
func shiftPgSQLPlaceholders(clause string, offset int) string {
	if offset == 0 || !strings.Contains(clause, "$") {
		return clause
	}

	var sb strings.Builder
	sb.Grow(len(clause) + 8)
	for i := 0; i < len(clause); {
		if clause[i] != '$' {
			sb.WriteByte(clause[i])
			i++
			continue
		}

		// Read the decimal number following the '$', if any.
		end := i + 1
		number := 0
		for end < len(clause) && clause[end] >= '0' && clause[end] <= '9' {
			number = number*10 + int(clause[end]-'0')
			end++
		}
		if end == i+1 {
			sb.WriteByte('$')
			i++
			continue
		}

		fmt.Fprintf(&sb, "$%d", number+offset)
		i = end
	}
	return sb.String()
}
// PgSQLDeleteQuery is the DELETE builder of the PostgreSQL adapter.
type PgSQLDeleteQuery struct {
	// Execution target: tx is used when set, otherwise db.
	db *sql.DB
	tx *sql.Tx

	// tableName is the table rows are deleted from.
	tableName string
	// whereClauses are joined with AND; their $N placeholders index into args.
	whereClauses []string
	args         []interface{}
	// paramCounter is the number of the last placeholder handed out by Where.
	paramCounter int
}
// Model adopts the model's table name when the model provides one; the model
// itself is not retained.
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
	named, hasName := model.(common.TableNameProvider)
	if hasName {
		p.tableName = named.TableName()
	}
	return p
}
// Table sets the table to delete from, replacing any name taken from Model.
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
	p.tableName = table

	return p
}
// Where adds a condition; conditions are joined with AND. One "?" per argument
// is rewritten to the next $N placeholder.
func (p *PgSQLDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery {
	numbered := p.replacePlaceholders(query, len(args))
	p.whereClauses = append(p.whereClauses, numbered)
	p.args = append(p.args, args...)
	return p
}
// replacePlaceholders rewrites the first argCount "?" markers in query to
// consecutive $N placeholders, continuing from the query-wide counter. The
// counter advances once per argument even if the text has fewer "?" markers.
func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) string {
	for remaining := argCount; remaining > 0; remaining-- {
		p.paramCounter++
		query = strings.Replace(query, "?", fmt.Sprintf("$%d", p.paramCounter), 1)
	}
	return query
}
// Exec builds and runs the DELETE statement. Without any Where condition the
// statement has no WHERE clause and therefore targets every row of the table.
// It runs on the transaction when one is attached, otherwise on the database
// handle. A failure is logged and returned; a panic is converted into the
// returned error by logger.HandlePanic.
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = logger.HandlePanic("PgSQLDeleteQuery.Exec", recovered)
		}
	}()

	var sqlText strings.Builder
	sqlText.WriteString("DELETE FROM ")
	sqlText.WriteString(p.tableName)
	if len(p.whereClauses) > 0 {
		sqlText.WriteString(" WHERE ")
		sqlText.WriteString(strings.Join(p.whereClauses, " AND "))
	}
	query := sqlText.String()

	logger.Debug("PgSQL DELETE: %s [args: %v]", query, p.args)

	var result sql.Result
	if p.tx == nil {
		result, err = p.db.ExecContext(ctx, query, p.args...)
	} else {
		result, err = p.tx.ExecContext(ctx, query, p.args...)
	}
	if err != nil {
		logger.Error("PgSQL DELETE failed: %v", err)
		return nil, err
	}
	return &PgSQLResult{result: result}, nil
}
// PgSQLResult adapts a database/sql result to the project's Result interface.
type PgSQLResult struct {
	// result is the wrapped driver result; its accessors tolerate nil.
	result sql.Result
}
// RowsAffected returns the number of rows the statement affected, or 0 when no
// driver result is wrapped. The driver's error is dropped because this method
// has no error return.
func (p *PgSQLResult) RowsAffected() int64 {
	if p.result != nil {
		affected, _ := p.result.RowsAffected()
		return affected
	}
	return 0
}
// LastInsertId forwards to the wrapped driver result, or returns (0, nil) when
// no driver result is wrapped.
func (p *PgSQLResult) LastInsertId() (int64, error) {
	if p.result != nil {
		return p.result.LastInsertId()
	}
	return 0, nil
}
// PgSQLTxAdapter is the Database implementation bound to one open transaction;
// every builder it creates and every statement it runs goes through tx.
type PgSQLTxAdapter struct {
	// tx is the transaction this adapter operates on.
	tx *sql.Tx
}
// NewSelect returns a SELECT builder that executes inside the transaction.
// The builder starts out selecting "*" with an empty argument list.
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
	q := &PgSQLSelectQuery{tx: p.tx}
	q.columns = []string{"*"}
	q.args = make([]interface{}, 0)
	return q
}
// NewInsert returns an INSERT builder that executes inside the transaction,
// with an empty column/value map ready for Value calls.
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
	q := &PgSQLInsertQuery{tx: p.tx}
	q.values = make(map[string]interface{})
	return q
}
// NewUpdate returns an UPDATE builder that executes inside the transaction,
// with empty SET, WHERE and argument collections.
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
	q := &PgSQLUpdateQuery{tx: p.tx}
	q.sets = make(map[string]interface{})
	q.args = make([]interface{}, 0)
	q.whereClauses = make([]string, 0)
	return q
}
// NewDelete returns a DELETE builder that executes inside the transaction,
// with empty WHERE and argument collections.
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
	q := &PgSQLDeleteQuery{tx: p.tx}
	q.args = make([]interface{}, 0)
	q.whereClauses = make([]string, 0)
	return q
}
// Exec runs a statement that returns no rows inside the transaction and wraps
// the driver result in a PgSQLResult. The statement is logged at debug level;
// a failure is logged and returned unchanged.
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
	logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)

	sqlResult, err := p.tx.ExecContext(ctx, query, args...)
	if err == nil {
		return &PgSQLResult{result: sqlResult}, nil
	}
	logger.Error("PgSQL Tx Exec failed: %v", err)
	return nil, err
}
// Query runs a row-returning statement inside the transaction and decodes the
// rows into dest via scanRows. The statement is logged at debug level; a query
// failure is logged and returned unchanged. The result set is closed before
// returning.
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
	logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)

	rows, queryErr := p.tx.QueryContext(ctx, query, args...)
	if queryErr != nil {
		logger.Error("PgSQL Tx Query failed: %v", queryErr)
		return queryErr
	}
	defer rows.Close()

	return scanRows(rows, dest)
}
// BeginTx always fails: a transaction adapter cannot open a nested transaction.
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
	err := fmt.Errorf("nested transactions not supported")
	return nil, err
}
// CommitTx commits the wrapped transaction. ctx is accepted for interface
// compatibility and is not used.
func (p *PgSQLTxAdapter) CommitTx(ctx context.Context) error {
	err := p.tx.Commit()
	return err
}
// RollbackTx rolls back the wrapped transaction. ctx is accepted for interface
// compatibility and is not used.
func (p *PgSQLTxAdapter) RollbackTx(ctx context.Context) error {
	err := p.tx.Rollback()
	return err
}
// RunInTransaction runs fn against this adapter, i.e. inside the transaction
// that is already open. It neither commits nor rolls back; that stays with
// whoever owns the transaction.
func (p *PgSQLTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
	err := fn(p)
	return err
}
// applyJoinPreloads adds a LEFT JOIN for every preload that uses the JOIN
// strategy.
//
// The join condition is "<parent>.<foreignKey> = <alias>.<targetKey>", where
// <parent> is the query's table alias or, when no alias is set, its table name
// and <alias> is the lower-cased relation name. A relation is skipped with a
// warning when its metadata cannot be determined or when a table name needed
// for the JOIN is missing. A JOIN that is already present is not added again,
// so calling Scan more than once does not duplicate joins.
func (p *PgSQLSelectQuery) applyJoinPreloads() {
	// Qualifier for the parent side of the join condition.
	parentRef := p.tableAlias
	if parentRef == "" {
		parentRef = p.tableName
	}

	for _, preload := range p.preloads {
		if !preload.useJoin {
			continue
		}

		// Build JOIN based on relationship metadata
		meta := p.getRelationMetadata(preload.relation)
		if meta == nil {
			logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation)
			continue
		}
		if meta.targetTable == "" || parentRef == "" {
			// Either side missing would render invalid SQL.
			logger.Warn("Cannot build JOIN for relation '%s': missing table name", preload.relation)
			continue
		}

		relationAlias := strings.ToLower(preload.relation)
		joinClause := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
			meta.targetTable,
			relationAlias,
			parentRef,
			meta.foreignKey,
			relationAlias,
			meta.targetKey,
		)
		joinSQL := "LEFT JOIN " + joinClause

		alreadyJoined := false
		for _, existing := range p.joins {
			if existing == joinSQL {
				alreadyJoined = true
				break
			}
		}
		if alreadyJoined {
			continue
		}

		logger.Debug("Adding LEFT JOIN for relation '%s': %s", preload.relation, joinClause)
		p.joins = append(p.joins, joinSQL)

		// Custom apply functions are not integrated into the JOIN yet; they
		// are only reported.
		if len(preload.applyFuncs) > 0 {
			logger.Warn("Custom conditions in JoinRelation not yet fully implemented")
		}
	}
}
// applySubqueryPreloads loads every preload that does not use the JOIN
// strategy by running follow-up queries for the scanned records.
//
// dest must be a pointer. For a slice, each element is processed and a
// per-record failure is only logged; for a single struct the record's error is
// returned. Any other destination is left alone.
func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest interface{}) error {
	var pending []preloadConfig
	for _, preload := range p.preloads {
		if preload.useJoin {
			continue
		}
		pending = append(pending, preload)
	}
	if len(pending) == 0 {
		return nil
	}

	destPtr := reflect.ValueOf(dest)
	if destPtr.Kind() != reflect.Ptr {
		return fmt.Errorf("dest must be a pointer")
	}

	switch target := destPtr.Elem(); target.Kind() {
	case reflect.Slice:
		for idx := 0; idx < target.Len(); idx++ {
			if err := p.loadPreloadsForRecord(ctx, target.Index(idx), pending); err != nil {
				logger.Warn("Failed to load preloads for record %d: %v", idx, err)
			}
		}
		return nil
	case reflect.Struct:
		return p.loadPreloadsForRecord(ctx, target, pending)
	}
	return nil
}
// loadPreloadsForRecord loads all given preload relations for one record.
//
// record may be a struct or a pointer to one; a nil pointer is skipped and any
// other kind is rejected with an error instead of panicking in reflection.
//
// The relation's foreign key comes from the "join:" struct tag and is usually
// a column name such as "user_id" (default "id"). It is first tried as a Go
// field name and then resolved through the column-to-field mapping of
// buildFieldMap, so column-style names find their exported field.
//
// Problems with an individual relation are logged as warnings and do not stop
// the remaining relations from loading.
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
	if record.Kind() == reflect.Ptr {
		if record.IsNil() {
			return nil
		}
		record = record.Elem()
	}
	if record.Kind() != reflect.Struct {
		return fmt.Errorf("cannot preload into record of kind %s", record.Kind())
	}

	// Column name (lower-cased) -> field, used to resolve foreign key columns.
	columnFields := buildFieldMap(record.Type(), nil)

	for _, preload := range preloads {
		field := record.FieldByName(preload.relation)
		if !field.IsValid() || !field.CanSet() {
			logger.Warn("Field '%s' not found or cannot be set", preload.relation)
			continue
		}

		meta := p.getRelationMetadataFromField(record.Type(), preload.relation)
		if meta == nil {
			logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation)
			continue
		}

		// Get the foreign key value from the parent record
		fkField := record.FieldByName(meta.foreignKey)
		if !fkField.IsValid() || !fkField.CanInterface() {
			if info, ok := columnFields[strings.ToLower(meta.foreignKey)]; ok {
				fkField = record.FieldByIndex(info.Index)
			}
		}
		if !fkField.IsValid() || !fkField.CanInterface() {
			logger.Warn("Foreign key field '%s' not found", meta.foreignKey)
			continue
		}
		fkValue := fkField.Interface()

		// Build and execute the preload query
		if err := p.executePreloadQuery(ctx, field, meta, fkValue, preload); err != nil {
			logger.Warn("Failed to execute preload query for '%s': %v", preload.relation, err)
		}
	}
	return nil
}
// executePreloadQuery loads one relation of one record and stores the result
// in field.
//
// The query selects from meta.targetTable where meta.targetKey equals fkValue
// and runs on the same transaction or database handle as the parent query.
// Each non-nil apply function of the preload may modify the query first.
//
// A slice field receives all matching rows. A struct or pointer field receives
// the first matching row; when no row matches, the field is reset to its zero
// value — a nil pointer for pointer fields — so a missing relation cannot be
// mistaken for an existing but empty one.
func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflect.Value, meta *relationMetadata, fkValue interface{}, preload preloadConfig) error {
	// Create a new select query for the related table
	var db common.Database
	if p.tx != nil {
		db = &PgSQLTxAdapter{tx: p.tx}
	} else {
		db = &PgSQLAdapter{db: p.db}
	}

	query := db.NewSelect().
		Table(meta.targetTable).
		Where(fmt.Sprintf("%s = ?", meta.targetKey), fkValue)

	// Apply custom functions
	for _, applyFunc := range preload.applyFuncs {
		if applyFunc != nil {
			query = applyFunc(query)
		}
	}

	// Slice field (has-many): collect all rows.
	if field.Kind() == reflect.Slice {
		results := reflect.New(field.Type()).Elem()
		if err := query.Scan(ctx, results.Addr().Interface()); err != nil {
			return err
		}
		field.Set(results)
		return nil
	}

	// Single struct: scan into a fresh value of the underlying struct type.
	isPtr := field.Kind() == reflect.Ptr
	var target reflect.Value
	if isPtr {
		target = reflect.New(field.Type().Elem())
	} else {
		target = reflect.New(field.Type())
	}

	// Execute query with LIMIT 1
	err := query.Limit(1).Scan(ctx, target.Interface())
	if err == sql.ErrNoRows {
		// No related row: leave no placeholder object behind.
		field.Set(reflect.Zero(field.Type()))
		return nil
	}
	if err != nil {
		return err
	}

	if isPtr {
		field.Set(target)
	} else {
		field.Set(target.Elem())
	}
	return nil
}
// getRelationMetadata derives the relation metadata for fieldName from the
// query's model. It returns nil when no model is set or when the field cannot
// be resolved on the model's (pointer-dereferenced) type.
func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetadata {
	if p.model == nil {
		return nil
	}

	t := reflect.TypeOf(p.model)
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return p.getRelationMetadataFromField(t, fieldName)
}
// getRelationMetadataFromField derives relation metadata for the field named
// fieldName of the struct type modelType. It returns nil when modelType is not
// a struct or has no such field.
//
// Keys come from the field's bun tag option "join:<foreign>=<target>"
// (e.g. `bun:"rel:has-many,join:user_id=id"`); both default to "id" when the
// option is missing or malformed. The target table is taken from TableName()
// of the related struct type (after unwrapping slice and pointer), if it
// provides one.
func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type, fieldName string) *relationMetadata {
	if modelType.Kind() != reflect.Struct {
		return nil
	}
	structField, ok := modelType.FieldByName(fieldName)
	if !ok {
		return nil
	}

	meta := &relationMetadata{
		fieldName:    fieldName,
		relationType: reflection.GetRelationType(reflect.New(modelType).Interface(), fieldName),
	}

	// Only "join:" options with exactly one "=" are honoured.
	for _, option := range strings.Split(structField.Tag.Get("bun"), ",") {
		option = strings.TrimSpace(option)
		if !strings.HasPrefix(option, "join:") {
			continue
		}
		sides := strings.Split(strings.TrimPrefix(option, "join:"), "=")
		if len(sides) != 2 {
			continue
		}
		meta.foreignKey = strings.TrimSpace(sides[0])
		meta.targetKey = strings.TrimSpace(sides[1])
	}

	// Unwrap []T / []*T / *T down to the related struct type.
	relatedType := structField.Type
	if relatedType.Kind() == reflect.Slice {
		relatedType = relatedType.Elem()
	}
	if relatedType.Kind() == reflect.Ptr {
		relatedType = relatedType.Elem()
	}
	if relatedType.Kind() == reflect.Struct {
		if named, hasName := reflect.New(relatedType).Interface().(common.TableNameProvider); hasName {
			meta.targetTable = named.TableName()
		}
	}

	if meta.foreignKey == "" {
		meta.foreignKey = "id"
	}
	if meta.targetKey == "" {
		meta.targetKey = "id"
	}
	return meta
}
// scanRows decodes rows into dest, which must be a non-nil pointer to one of:
//
//   - []map[string]interface{}: one map per row, keyed by column name
//   - a slice of structs or struct pointers: one element per row
//   - a struct: the first row only
//
// Any other destination, including a nil pointer, is rejected with an error
// rather than panicking inside reflection.
func scanRows(rows *sql.Rows, dest interface{}) error {
	// Get column names
	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	// Get destination type
	destValue := reflect.ValueOf(dest)
	if destValue.Kind() != reflect.Ptr {
		return fmt.Errorf("dest must be a pointer")
	}
	if destValue.IsNil() {
		return fmt.Errorf("dest must be a non-nil pointer")
	}
	destValue = destValue.Elem()

	// The map-slice check comes first because it is a slice kind as well.
	switch {
	case destValue.Type() == reflect.TypeOf([]map[string]interface{}{}):
		return scanRowsToMapSlice(rows, columns, destValue)
	case destValue.Kind() == reflect.Slice:
		return scanRowsToStructSlice(rows, columns, destValue)
	case destValue.Kind() == reflect.Struct:
		return scanRowsToSingleStruct(rows, columns, destValue)
	}
	return fmt.Errorf("unsupported destination type: %T", dest)
}
// scanRowsToMapSlice reads every remaining row into a map keyed by column name
// and stores the resulting []map[string]interface{} in destValue. On a scan
// error it returns immediately and leaves destValue untouched; otherwise it
// returns rows.Err().
func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
	width := len(columns)
	collected := make([]map[string]interface{}, 0)

	for rows.Next() {
		// cells receives the values; targets are the pointers handed to Scan.
		cells := make([]interface{}, width)
		targets := make([]interface{}, width)
		for idx := range cells {
			targets[idx] = &cells[idx]
		}

		if err := rows.Scan(targets...); err != nil {
			return err
		}

		record := make(map[string]interface{})
		for idx, name := range columns {
			record[name] = cells[idx]
		}
		collected = append(collected, record)
	}

	destValue.Set(reflect.ValueOf(collected))
	return rows.Err()
}
// scanRowsToStructSlice appends one element per remaining row to destValue,
// which must be a slice of structs or of struct pointers.
//
// Columns are matched to fields case-insensitively through buildFieldMap. The
// match is resolved once before iterating, not once per row. Columns without a
// settable field are scanned into a throwaway value. A scan failure is
// returned wrapped as "scan failed: ..."; otherwise rows.Err() is returned.
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
	elemType := destValue.Type().Elem()
	isPtr := elemType.Kind() == reflect.Ptr
	if isPtr {
		elemType = elemType.Elem()
	}
	if elemType.Kind() != reflect.Struct {
		return fmt.Errorf("slice element must be a struct, got %v", elemType.Kind())
	}

	// Resolve each column to a field index path; nil marks an unmapped column.
	// buildFieldMap keys are lower-cased, so the lookup is lower-cased as well.
	fieldMap := buildFieldMap(elemType, columns)
	fieldIndexes := make([][]int, len(columns))
	for i, col := range columns {
		if info, ok := fieldMap[strings.ToLower(col)]; ok {
			fieldIndexes[i] = info.Index
		}
	}

	for rows.Next() {
		// Create a new instance of the struct
		elemValue := reflect.New(elemType).Elem()

		scanTargets := make([]interface{}, len(columns))
		for i, index := range fieldIndexes {
			if index != nil {
				if field := elemValue.FieldByIndex(index); field.CanSet() {
					scanTargets[i] = field.Addr().Interface()
					continue
				}
			}
			// Use a dummy variable for unmapped columns
			scanTargets[i] = new(interface{})
		}

		if err := rows.Scan(scanTargets...); err != nil {
			return fmt.Errorf("scan failed: %w", err)
		}

		if isPtr {
			destValue.Set(reflect.Append(destValue, elemValue.Addr()))
		} else {
			destValue.Set(reflect.Append(destValue, elemValue))
		}
	}
	return rows.Err()
}
// scanRowsToSingleStruct scans the first row into the struct destValue.
//
// When there is no row to read, the iteration error is returned if one
// occurred and sql.ErrNoRows otherwise, so a failed read is not reported as
// "no rows". Columns are matched to fields case-insensitively through
// buildFieldMap; columns without a settable field are scanned into a throwaway
// value. A scan failure is returned wrapped as "scan failed: ...".
func scanRowsToSingleStruct(rows *sql.Rows, columns []string, destValue reflect.Value) error {
	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return err
		}
		return sql.ErrNoRows
	}

	// buildFieldMap keys are lower-cased, so the lookup is lower-cased as well.
	fieldMap := buildFieldMap(destValue.Type(), columns)

	scanTargets := make([]interface{}, len(columns))
	for i, col := range columns {
		if info, ok := fieldMap[strings.ToLower(col)]; ok {
			if field := destValue.FieldByIndex(info.Index); field.CanSet() {
				scanTargets[i] = field.Addr().Interface()
				continue
			}
		}
		// Use a dummy variable for unmapped columns
		scanTargets[i] = new(interface{})
	}

	if err := rows.Scan(scanTargets...); err != nil {
		return fmt.Errorf("scan failed: %w", err)
	}
	return rows.Err()
}
// fieldInfo identifies a struct field that a result column can be scanned into.
type fieldInfo struct {
	// Index is the field's index path, suitable for reflect.Value.FieldByIndex.
	Index []int
	// Name is the Go field name.
	Name string
}

// buildFieldMap maps lower-cased column names to the exported fields of
// structType. The second parameter is unused and kept for call compatibility.
//
// A field is reachable under two keys:
//
//   - its explicit column name: the `db` tag or, failing that, the first
//     element of the `bun` tag. Empty values, "-" and option-only bun elements
//     (those containing ':', such as "rel:has-many") are not column names.
//   - its lower-cased Go field name, as a fallback.
//
// Explicit names take precedence: a field-name fallback never replaces a key
// that some field claimed through a tag, whatever the declaration order.
// Between two explicit claims, or two fallbacks, the later field wins.
func buildFieldMap(structType reflect.Type, _ []string) map[string]fieldInfo {
	fieldMap := make(map[string]fieldInfo, structType.NumField()*2)
	// tagged records the keys that came from an explicit tag.
	tagged := make(map[string]bool)

	for i := 0; i < structType.NumField(); i++ {
		field := structType.Field(i)

		// Skip unexported fields
		if !field.IsExported() {
			continue
		}
		info := fieldInfo{Index: field.Index, Name: field.Name}

		// Explicit column name: bun tag first, db tag overrides it.
		explicit := ""
		if bunTag := field.Tag.Get("bun"); bunTag != "" {
			name := strings.SplitN(bunTag, ",", 2)[0]
			if name != "" && name != "-" && !strings.Contains(name, ":") {
				explicit = name
			}
		}
		if dbTag := field.Tag.Get("db"); dbTag != "" && dbTag != "-" {
			explicit = dbTag
		}

		if explicit != "" {
			key := strings.ToLower(explicit)
			fieldMap[key] = info
			tagged[key] = true
		}

		// Field-name fallback, unless the key is already claimed by a tag.
		if nameKey := strings.ToLower(field.Name); !tagged[nameKey] {
			fieldMap[nameKey] = info
		}
	}
	return fieldMap
}