mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-13 17:10:36 +00:00
1364 lines
34 KiB
Go
1364 lines
34 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
|
)
|
|
|
|
// PgSQLAdapter adapts standard database/sql to work with our Database interface
|
|
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
|
type PgSQLAdapter struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
// NewPgSQLAdapter creates a new PostgreSQL adapter
|
|
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
|
|
return &PgSQLAdapter{db: db}
|
|
}
|
|
|
|
// EnableQueryDebug enables query debugging for development
|
|
func (p *PgSQLAdapter) EnableQueryDebug() {
|
|
logger.Info("PgSQL query debug mode - logging enabled via logger")
|
|
}
|
|
|
|
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
|
return &PgSQLSelectQuery{
|
|
db: p.db,
|
|
columns: []string{"*"},
|
|
args: make([]interface{}, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
|
return &PgSQLInsertQuery{
|
|
db: p.db,
|
|
values: make(map[string]interface{}),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
|
return &PgSQLUpdateQuery{
|
|
db: p.db,
|
|
sets: make(map[string]interface{}),
|
|
args: make([]interface{}, 0),
|
|
whereClauses: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
|
return &PgSQLDeleteQuery{
|
|
db: p.db,
|
|
args: make([]interface{}, 0),
|
|
whereClauses: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
|
|
}
|
|
}()
|
|
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
|
result, err := p.db.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
logger.Error("PgSQL Exec failed: %v", err)
|
|
return nil, err
|
|
}
|
|
return &PgSQLResult{result: result}, nil
|
|
}
|
|
|
|
func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLAdapter.Query", r)
|
|
}
|
|
}()
|
|
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
|
rows, err := p.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
logger.Error("PgSQL Query failed: %v", err)
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return scanRows(rows, dest)
|
|
}
|
|
|
|
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
|
tx, err := p.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &PgSQLTxAdapter{tx: tx}, nil
|
|
}
|
|
|
|
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
|
return fmt.Errorf("CommitTx should be called on transaction adapter")
|
|
}
|
|
|
|
func (p *PgSQLAdapter) RollbackTx(ctx context.Context) error {
|
|
return fmt.Errorf("RollbackTx should be called on transaction adapter")
|
|
}
|
|
|
|
func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLAdapter.RunInTransaction", r)
|
|
}
|
|
}()
|
|
|
|
tx, err := p.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
adapter := &PgSQLTxAdapter{tx: tx}
|
|
|
|
defer func() {
|
|
if p := recover(); p != nil {
|
|
_ = tx.Rollback()
|
|
panic(p)
|
|
} else if err != nil {
|
|
_ = tx.Rollback()
|
|
} else {
|
|
err = tx.Commit()
|
|
}
|
|
}()
|
|
|
|
return fn(adapter)
|
|
}
|
|
|
|
func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
|
|
return p.db
|
|
}
|
|
|
|
// preloadConfig represents a relationship to be preloaded
|
|
type preloadConfig struct {
|
|
relation string
|
|
conditions []interface{}
|
|
applyFuncs []func(common.SelectQuery) common.SelectQuery
|
|
useJoin bool
|
|
}
|
|
|
|
// relationMetadata contains information about a relationship
|
|
type relationMetadata struct {
|
|
fieldName string
|
|
relationType reflection.RelationType
|
|
foreignKey string
|
|
targetTable string
|
|
targetKey string
|
|
}
|
|
|
|
// PgSQLSelectQuery implements SelectQuery for PostgreSQL
|
|
type PgSQLSelectQuery struct {
|
|
db *sql.DB
|
|
tx *sql.Tx
|
|
model interface{}
|
|
tableName string
|
|
tableAlias string
|
|
columns []string
|
|
columnExprs []string
|
|
whereClauses []string
|
|
orClauses []string
|
|
joins []string
|
|
orderBy []string
|
|
groupBy []string
|
|
havingClauses []string
|
|
limit int
|
|
offset int
|
|
args []interface{}
|
|
paramCounter int
|
|
preloads []preloadConfig
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
|
p.model = model
|
|
if provider, ok := model.(common.TableNameProvider); ok {
|
|
p.tableName = provider.TableName()
|
|
}
|
|
if provider, ok := model.(common.TableAliasProvider); ok {
|
|
p.tableAlias = provider.TableAlias()
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
|
p.tableName = table
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Column(columns ...string) common.SelectQuery {
|
|
if len(p.columns) == 1 && p.columns[0] == "*" {
|
|
p.columns = make([]string, 0)
|
|
}
|
|
p.columns = append(p.columns, columns...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
|
p.columnExprs = append(p.columnExprs, query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
|
// Replace ? placeholders with $1, $2, etc.
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.whereClauses = append(p.whereClauses, query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.orClauses = append(p.orClauses, query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.joins = append(p.joins, "JOIN "+query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.joins = append(p.joins, "LEFT JOIN "+query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
|
p.preloads = append(p.preloads, preloadConfig{
|
|
relation: relation,
|
|
conditions: conditions,
|
|
useJoin: false, // Always use subquery for simple Preload
|
|
})
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
|
// Auto-detect relationship type and choose optimal loading strategy
|
|
var useJoin bool
|
|
if p.model != nil {
|
|
relType := reflection.GetRelationType(p.model, relation)
|
|
useJoin = relType.ShouldUseJoin()
|
|
logger.Debug("PreloadRelation '%s' detected as: %s (useJoin: %v)", relation, relType, useJoin)
|
|
}
|
|
|
|
p.preloads = append(p.preloads, preloadConfig{
|
|
relation: relation,
|
|
applyFuncs: apply,
|
|
useJoin: useJoin,
|
|
})
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
|
// Force JOIN loading
|
|
logger.Debug("JoinRelation '%s' - forcing JOIN strategy", relation)
|
|
p.preloads = append(p.preloads, preloadConfig{
|
|
relation: relation,
|
|
applyFuncs: apply,
|
|
useJoin: true, // Force JOIN
|
|
})
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery {
|
|
p.orderBy = append(p.orderBy, order)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery {
|
|
p.limit = n
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Offset(n int) common.SelectQuery {
|
|
p.offset = n
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Group(group string) common.SelectQuery {
|
|
p.groupBy = append(p.groupBy, group)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Having(having string, args ...interface{}) common.SelectQuery {
|
|
having = p.replacePlaceholders(having, len(args))
|
|
p.havingClauses = append(p.havingClauses, having)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) buildSQL() string {
|
|
var sb strings.Builder
|
|
|
|
// SELECT clause
|
|
sb.WriteString("SELECT ")
|
|
if len(p.columns) > 0 || len(p.columnExprs) > 0 {
|
|
allCols := make([]string, 0)
|
|
allCols = append(allCols, p.columns...)
|
|
allCols = append(allCols, p.columnExprs...)
|
|
sb.WriteString(strings.Join(allCols, ", "))
|
|
} else {
|
|
sb.WriteString("*")
|
|
}
|
|
|
|
// FROM clause
|
|
if p.tableName != "" {
|
|
sb.WriteString(" FROM ")
|
|
sb.WriteString(p.tableName)
|
|
if p.tableAlias != "" {
|
|
sb.WriteString(" AS ")
|
|
sb.WriteString(p.tableAlias)
|
|
}
|
|
}
|
|
|
|
// JOIN clauses
|
|
if len(p.joins) > 0 {
|
|
sb.WriteString(" ")
|
|
sb.WriteString(strings.Join(p.joins, " "))
|
|
}
|
|
|
|
// WHERE clause
|
|
if len(p.whereClauses) > 0 || len(p.orClauses) > 0 {
|
|
sb.WriteString(" WHERE ")
|
|
conditions := make([]string, 0)
|
|
|
|
if len(p.whereClauses) > 0 {
|
|
conditions = append(conditions, "("+strings.Join(p.whereClauses, " AND ")+")")
|
|
}
|
|
if len(p.orClauses) > 0 {
|
|
conditions = append(conditions, "("+strings.Join(p.orClauses, " OR ")+")")
|
|
}
|
|
|
|
sb.WriteString(strings.Join(conditions, " AND "))
|
|
}
|
|
|
|
// GROUP BY clause
|
|
if len(p.groupBy) > 0 {
|
|
sb.WriteString(" GROUP BY ")
|
|
sb.WriteString(strings.Join(p.groupBy, ", "))
|
|
}
|
|
|
|
// HAVING clause
|
|
if len(p.havingClauses) > 0 {
|
|
sb.WriteString(" HAVING ")
|
|
sb.WriteString(strings.Join(p.havingClauses, " AND "))
|
|
}
|
|
|
|
// ORDER BY clause
|
|
if len(p.orderBy) > 0 {
|
|
sb.WriteString(" ORDER BY ")
|
|
sb.WriteString(strings.Join(p.orderBy, ", "))
|
|
}
|
|
|
|
// LIMIT clause
|
|
if p.limit > 0 {
|
|
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
|
|
}
|
|
|
|
// OFFSET clause
|
|
if p.offset > 0 {
|
|
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
|
|
}
|
|
|
|
return sb.String()
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) replacePlaceholders(query string, argCount int) string {
|
|
result := query
|
|
for i := 0; i < argCount; i++ {
|
|
p.paramCounter++
|
|
placeholder := fmt.Sprintf("$%d", p.paramCounter)
|
|
result = strings.Replace(result, "?", placeholder, 1)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
|
|
}
|
|
}()
|
|
|
|
// Apply preloads that use JOINs
|
|
p.applyJoinPreloads()
|
|
|
|
query := p.buildSQL()
|
|
logger.Debug("PgSQL SELECT: %s [args: %v]", query, p.args)
|
|
|
|
var rows *sql.Rows
|
|
if p.tx != nil {
|
|
rows, err = p.tx.QueryContext(ctx, query, p.args...)
|
|
} else {
|
|
rows, err = p.db.QueryContext(ctx, query, p.args...)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Error("PgSQL SELECT failed: %v", err)
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
err = scanRows(rows, dest)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Apply preloads that use separate queries
|
|
return p.applySubqueryPreloads(ctx, dest)
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
|
if p.model == nil {
|
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
|
}
|
|
return p.Scan(ctx, p.model)
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
|
count = 0
|
|
}
|
|
}()
|
|
|
|
// Build a COUNT query
|
|
var sb strings.Builder
|
|
sb.WriteString("SELECT COUNT(*) FROM ")
|
|
sb.WriteString(p.tableName)
|
|
|
|
if len(p.joins) > 0 {
|
|
sb.WriteString(" ")
|
|
sb.WriteString(strings.Join(p.joins, " "))
|
|
}
|
|
|
|
if len(p.whereClauses) > 0 || len(p.orClauses) > 0 {
|
|
sb.WriteString(" WHERE ")
|
|
conditions := make([]string, 0)
|
|
|
|
if len(p.whereClauses) > 0 {
|
|
conditions = append(conditions, "("+strings.Join(p.whereClauses, " AND ")+")")
|
|
}
|
|
if len(p.orClauses) > 0 {
|
|
conditions = append(conditions, "("+strings.Join(p.orClauses, " OR ")+")")
|
|
}
|
|
|
|
sb.WriteString(strings.Join(conditions, " AND "))
|
|
}
|
|
|
|
query := sb.String()
|
|
logger.Debug("PgSQL COUNT: %s [args: %v]", query, p.args)
|
|
|
|
var row *sql.Row
|
|
if p.tx != nil {
|
|
row = p.tx.QueryRowContext(ctx, query, p.args...)
|
|
} else {
|
|
row = p.db.QueryRowContext(ctx, query, p.args...)
|
|
}
|
|
|
|
err = row.Scan(&count)
|
|
if err != nil {
|
|
logger.Error("PgSQL COUNT failed: %v", err)
|
|
}
|
|
return count, err
|
|
}
|
|
|
|
func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLSelectQuery.Exists", r)
|
|
exists = false
|
|
}
|
|
}()
|
|
|
|
count, err := p.Count(ctx)
|
|
return count > 0, err
|
|
}
|
|
|
|
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
|
|
type PgSQLInsertQuery struct {
|
|
db *sql.DB
|
|
tx *sql.Tx
|
|
tableName string
|
|
values map[string]interface{}
|
|
returning []string
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
|
if provider, ok := model.(common.TableNameProvider); ok {
|
|
p.tableName = provider.TableName()
|
|
}
|
|
// Extract values from model using reflection
|
|
// This is a simplified implementation
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
|
p.tableName = table
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
|
|
p.values[column] = value
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) OnConflict(action string) common.InsertQuery {
|
|
logger.Warn("OnConflict not yet implemented in PgSQL adapter")
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|
p.returning = columns
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
|
|
}
|
|
}()
|
|
|
|
if len(p.values) == 0 {
|
|
return nil, fmt.Errorf("no values to insert")
|
|
}
|
|
|
|
columns := make([]string, 0, len(p.values))
|
|
placeholders := make([]string, 0, len(p.values))
|
|
args := make([]interface{}, 0, len(p.values))
|
|
|
|
i := 1
|
|
for col, val := range p.values {
|
|
columns = append(columns, col)
|
|
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
|
args = append(args, val)
|
|
i++
|
|
}
|
|
|
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
|
p.tableName,
|
|
strings.Join(columns, ", "),
|
|
strings.Join(placeholders, ", "))
|
|
|
|
if len(p.returning) > 0 {
|
|
query += " RETURNING " + strings.Join(p.returning, ", ")
|
|
}
|
|
|
|
logger.Debug("PgSQL INSERT: %s [args: %v]", query, args)
|
|
|
|
var result sql.Result
|
|
if p.tx != nil {
|
|
result, err = p.tx.ExecContext(ctx, query, args...)
|
|
} else {
|
|
result, err = p.db.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Error("PgSQL INSERT failed: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return &PgSQLResult{result: result}, nil
|
|
}
|
|
|
|
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
|
type PgSQLUpdateQuery struct {
|
|
db *sql.DB
|
|
tx *sql.Tx
|
|
tableName string
|
|
model interface{}
|
|
sets map[string]interface{}
|
|
whereClauses []string
|
|
args []interface{}
|
|
paramCounter int
|
|
returning []string
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
|
p.model = model
|
|
if provider, ok := model.(common.TableNameProvider); ok {
|
|
p.tableName = provider.TableName()
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
|
p.tableName = table
|
|
if p.model == nil {
|
|
model, err := modelregistry.GetModelByName(table)
|
|
if err == nil {
|
|
p.model = model
|
|
}
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
|
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
|
return p
|
|
}
|
|
p.sets[column] = value
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
|
pkName := ""
|
|
if p.model != nil {
|
|
pkName = reflection.GetPrimaryKeyName(p.model)
|
|
}
|
|
|
|
for column, value := range values {
|
|
if pkName != "" && column == pkName {
|
|
continue
|
|
}
|
|
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
|
continue
|
|
}
|
|
p.sets[column] = value
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery {
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.whereClauses = append(p.whereClauses, query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
|
p.returning = columns
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) string {
|
|
result := query
|
|
for i := 0; i < argCount; i++ {
|
|
p.paramCounter++
|
|
placeholder := fmt.Sprintf("$%d", p.paramCounter)
|
|
result = strings.Replace(result, "?", placeholder, 1)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
|
|
}
|
|
}()
|
|
|
|
if len(p.sets) == 0 {
|
|
return nil, fmt.Errorf("no values to update")
|
|
}
|
|
|
|
setClauses := make([]string, 0, len(p.sets))
|
|
setArgs := make([]interface{}, 0, len(p.sets))
|
|
|
|
// SET parameters start at $1
|
|
i := 1
|
|
for col, val := range p.sets {
|
|
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
|
|
setArgs = append(setArgs, val)
|
|
i++
|
|
}
|
|
|
|
query := fmt.Sprintf("UPDATE %s SET %s",
|
|
p.tableName,
|
|
strings.Join(setClauses, ", "))
|
|
|
|
// Update WHERE clause parameter numbers to continue after SET parameters
|
|
if len(p.whereClauses) > 0 {
|
|
updatedWhereClauses := make([]string, 0, len(p.whereClauses))
|
|
for _, whereClause := range p.whereClauses {
|
|
// Find and replace parameter placeholders
|
|
updatedClause := whereClause
|
|
paramNum := i
|
|
// Count how many parameters are in this WHERE clause
|
|
placeholderCount := strings.Count(whereClause, "$")
|
|
for j := 0; j < placeholderCount; j++ {
|
|
oldParam := fmt.Sprintf("$%d", j+1)
|
|
newParam := fmt.Sprintf("$%d", paramNum)
|
|
updatedClause = strings.Replace(updatedClause, oldParam, newParam, 1)
|
|
paramNum++
|
|
}
|
|
updatedWhereClauses = append(updatedWhereClauses, updatedClause)
|
|
i = paramNum
|
|
}
|
|
p.whereClauses = updatedWhereClauses
|
|
}
|
|
|
|
// All arguments: SET values first, then WHERE values
|
|
// Create a new slice to avoid modifying setArgs
|
|
allArgs := make([]interface{}, len(setArgs)+len(p.args))
|
|
copy(allArgs, setArgs)
|
|
copy(allArgs[len(setArgs):], p.args)
|
|
|
|
if len(p.whereClauses) > 0 {
|
|
query += " WHERE " + strings.Join(p.whereClauses, " AND ")
|
|
}
|
|
|
|
if len(p.returning) > 0 {
|
|
query += " RETURNING " + strings.Join(p.returning, ", ")
|
|
}
|
|
|
|
logger.Debug("PgSQL UPDATE: %s [args: %v]", query, allArgs)
|
|
|
|
var result sql.Result
|
|
if p.tx != nil {
|
|
result, err = p.tx.ExecContext(ctx, query, allArgs...)
|
|
} else {
|
|
result, err = p.db.ExecContext(ctx, query, allArgs...)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Error("PgSQL UPDATE failed: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return &PgSQLResult{result: result}, nil
|
|
}
|
|
|
|
// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL
|
|
type PgSQLDeleteQuery struct {
|
|
db *sql.DB
|
|
tx *sql.Tx
|
|
tableName string
|
|
whereClauses []string
|
|
args []interface{}
|
|
paramCounter int
|
|
}
|
|
|
|
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
|
if provider, ok := model.(common.TableNameProvider); ok {
|
|
p.tableName = provider.TableName()
|
|
}
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
|
p.tableName = table
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery {
|
|
query = p.replacePlaceholders(query, len(args))
|
|
p.whereClauses = append(p.whereClauses, query)
|
|
p.args = append(p.args, args...)
|
|
return p
|
|
}
|
|
|
|
func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) string {
|
|
result := query
|
|
for i := 0; i < argCount; i++ {
|
|
p.paramCounter++
|
|
placeholder := fmt.Sprintf("$%d", p.paramCounter)
|
|
result = strings.Replace(result, "?", placeholder, 1)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
|
|
}
|
|
}()
|
|
|
|
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
|
|
|
|
if len(p.whereClauses) > 0 {
|
|
query += " WHERE " + strings.Join(p.whereClauses, " AND ")
|
|
}
|
|
|
|
logger.Debug("PgSQL DELETE: %s [args: %v]", query, p.args)
|
|
|
|
var result sql.Result
|
|
if p.tx != nil {
|
|
result, err = p.tx.ExecContext(ctx, query, p.args...)
|
|
} else {
|
|
result, err = p.db.ExecContext(ctx, query, p.args...)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Error("PgSQL DELETE failed: %v", err)
|
|
return nil, err
|
|
}
|
|
|
|
return &PgSQLResult{result: result}, nil
|
|
}
|
|
|
|
// PgSQLResult implements Result for PostgreSQL
|
|
type PgSQLResult struct {
|
|
result sql.Result
|
|
}
|
|
|
|
func (p *PgSQLResult) RowsAffected() int64 {
|
|
if p.result == nil {
|
|
return 0
|
|
}
|
|
rows, _ := p.result.RowsAffected()
|
|
return rows
|
|
}
|
|
|
|
func (p *PgSQLResult) LastInsertId() (int64, error) {
|
|
if p.result == nil {
|
|
return 0, nil
|
|
}
|
|
return p.result.LastInsertId()
|
|
}
|
|
|
|
// PgSQLTxAdapter wraps a PostgreSQL transaction
|
|
type PgSQLTxAdapter struct {
|
|
tx *sql.Tx
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
|
return &PgSQLSelectQuery{
|
|
tx: p.tx,
|
|
columns: []string{"*"},
|
|
args: make([]interface{}, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
|
return &PgSQLInsertQuery{
|
|
tx: p.tx,
|
|
values: make(map[string]interface{}),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
|
return &PgSQLUpdateQuery{
|
|
tx: p.tx,
|
|
sets: make(map[string]interface{}),
|
|
args: make([]interface{}, 0),
|
|
whereClauses: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
|
return &PgSQLDeleteQuery{
|
|
tx: p.tx,
|
|
args: make([]interface{}, 0),
|
|
whereClauses: make([]string, 0),
|
|
}
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
|
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
|
|
result, err := p.tx.ExecContext(ctx, query, args...)
|
|
if err != nil {
|
|
logger.Error("PgSQL Tx Exec failed: %v", err)
|
|
return nil, err
|
|
}
|
|
return &PgSQLResult{result: result}, nil
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
|
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
|
|
rows, err := p.tx.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
logger.Error("PgSQL Tx Query failed: %v", err)
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return scanRows(rows, dest)
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
|
return nil, fmt.Errorf("nested transactions not supported")
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) CommitTx(ctx context.Context) error {
|
|
return p.tx.Commit()
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) RollbackTx(ctx context.Context) error {
|
|
return p.tx.Rollback()
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
|
return fn(p)
|
|
}
|
|
|
|
func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
|
|
return p.tx
|
|
}
|
|
|
|
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
|
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
|
for _, preload := range p.preloads {
|
|
if !preload.useJoin {
|
|
continue
|
|
}
|
|
|
|
// Build JOIN based on relationship metadata
|
|
meta := p.getRelationMetadata(preload.relation)
|
|
if meta == nil {
|
|
logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation)
|
|
continue
|
|
}
|
|
|
|
// Build the JOIN clause
|
|
relationAlias := strings.ToLower(preload.relation)
|
|
joinClause := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s",
|
|
meta.targetTable,
|
|
relationAlias,
|
|
p.tableAlias,
|
|
meta.foreignKey,
|
|
relationAlias,
|
|
meta.targetKey,
|
|
)
|
|
|
|
logger.Debug("Adding LEFT JOIN for relation '%s': %s", preload.relation, joinClause)
|
|
p.joins = append(p.joins, "LEFT JOIN "+joinClause)
|
|
|
|
// Apply any custom conditions through applyFuncs
|
|
// Note: These would need to be integrated into the WHERE clause
|
|
// For simplicity, we're logging a warning if custom conditions are present
|
|
if len(preload.applyFuncs) > 0 {
|
|
logger.Warn("Custom conditions in JoinRelation not yet fully implemented")
|
|
}
|
|
}
|
|
}
|
|
|
|
// applySubqueryPreloads executes separate queries for has-many and many-to-many relationships
|
|
func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest interface{}) error {
|
|
// Get all preloads that don't use JOIN
|
|
subqueryPreloads := make([]preloadConfig, 0)
|
|
for _, preload := range p.preloads {
|
|
if !preload.useJoin {
|
|
subqueryPreloads = append(subqueryPreloads, preload)
|
|
}
|
|
}
|
|
|
|
if len(subqueryPreloads) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Use reflection to process the destination
|
|
destValue := reflect.ValueOf(dest)
|
|
if destValue.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("dest must be a pointer")
|
|
}
|
|
|
|
destValue = destValue.Elem()
|
|
|
|
// Handle slice of structs
|
|
if destValue.Kind() == reflect.Slice {
|
|
for i := 0; i < destValue.Len(); i++ {
|
|
elem := destValue.Index(i)
|
|
if err := p.loadPreloadsForRecord(ctx, elem, subqueryPreloads); err != nil {
|
|
logger.Warn("Failed to load preloads for record %d: %v", i, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Handle single struct
|
|
if destValue.Kind() == reflect.Struct {
|
|
return p.loadPreloadsForRecord(ctx, destValue, subqueryPreloads)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// loadPreloadsForRecord loads all preload relationships for a single record
|
|
func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error {
|
|
if record.Kind() == reflect.Ptr {
|
|
if record.IsNil() {
|
|
return nil
|
|
}
|
|
record = record.Elem()
|
|
}
|
|
|
|
for _, preload := range preloads {
|
|
field := record.FieldByName(preload.relation)
|
|
if !field.IsValid() || !field.CanSet() {
|
|
logger.Warn("Field '%s' not found or cannot be set", preload.relation)
|
|
continue
|
|
}
|
|
|
|
meta := p.getRelationMetadataFromField(record.Type(), preload.relation)
|
|
if meta == nil {
|
|
logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation)
|
|
continue
|
|
}
|
|
|
|
// Get the foreign key value from the parent record
|
|
fkField := record.FieldByName(meta.foreignKey)
|
|
if !fkField.IsValid() {
|
|
logger.Warn("Foreign key field '%s' not found", meta.foreignKey)
|
|
continue
|
|
}
|
|
|
|
fkValue := fkField.Interface()
|
|
|
|
// Build and execute the preload query
|
|
err := p.executePreloadQuery(ctx, field, meta, fkValue, preload)
|
|
if err != nil {
|
|
logger.Warn("Failed to execute preload query for '%s': %v", preload.relation, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// executePreloadQuery executes a query to load a relationship
|
|
func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflect.Value, meta *relationMetadata, fkValue interface{}, preload preloadConfig) error {
|
|
// Create a new select query for the related table
|
|
var db common.Database
|
|
if p.tx != nil {
|
|
db = &PgSQLTxAdapter{tx: p.tx}
|
|
} else {
|
|
db = &PgSQLAdapter{db: p.db}
|
|
}
|
|
|
|
query := db.NewSelect().
|
|
Table(meta.targetTable).
|
|
Where(fmt.Sprintf("%s = ?", meta.targetKey), fkValue)
|
|
|
|
// Apply custom functions
|
|
for _, applyFunc := range preload.applyFuncs {
|
|
if applyFunc != nil {
|
|
query = applyFunc(query)
|
|
}
|
|
}
|
|
|
|
// Determine if this is a slice (has-many) or single struct (belongs-to/has-one)
|
|
if field.Kind() == reflect.Slice {
|
|
// Create a new slice to hold results
|
|
sliceType := field.Type()
|
|
results := reflect.New(sliceType).Elem()
|
|
|
|
// Execute query
|
|
err := query.Scan(ctx, results.Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Set the field
|
|
field.Set(results)
|
|
} else {
|
|
// Single struct - create a pointer if needed
|
|
var target reflect.Value
|
|
if field.Kind() == reflect.Ptr {
|
|
target = reflect.New(field.Type().Elem())
|
|
} else {
|
|
target = reflect.New(field.Type())
|
|
}
|
|
|
|
// Execute query with LIMIT 1
|
|
err := query.Limit(1).Scan(ctx, target.Interface())
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return err
|
|
}
|
|
|
|
// Set the field
|
|
if field.Kind() == reflect.Ptr {
|
|
field.Set(target)
|
|
} else {
|
|
field.Set(target.Elem())
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// getRelationMetadata extracts relationship metadata from the model
|
|
func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetadata {
|
|
if p.model == nil {
|
|
return nil
|
|
}
|
|
|
|
modelType := reflect.TypeOf(p.model)
|
|
if modelType.Kind() == reflect.Ptr {
|
|
modelType = modelType.Elem()
|
|
}
|
|
|
|
return p.getRelationMetadataFromField(modelType, fieldName)
|
|
}
|
|
|
|
// getRelationMetadataFromField extracts relationship metadata from a type
|
|
func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type, fieldName string) *relationMetadata {
|
|
if modelType.Kind() != reflect.Struct {
|
|
return nil
|
|
}
|
|
|
|
field, found := modelType.FieldByName(fieldName)
|
|
if !found {
|
|
return nil
|
|
}
|
|
|
|
meta := &relationMetadata{
|
|
fieldName: fieldName,
|
|
relationType: reflection.GetRelationType(reflect.New(modelType).Interface(), fieldName),
|
|
}
|
|
|
|
// Parse struct tags to get foreign key and target table
|
|
bunTag := field.Tag.Get("bun")
|
|
if bunTag != "" {
|
|
// Parse bun tags: rel:has-many,join:user_id=id
|
|
parts := strings.Split(bunTag, ",")
|
|
for _, part := range parts {
|
|
part = strings.TrimSpace(part)
|
|
if strings.HasPrefix(part, "join:") {
|
|
// Parse join condition: join:user_id=id
|
|
joinSpec := strings.TrimPrefix(part, "join:")
|
|
if strings.Contains(joinSpec, "=") {
|
|
keys := strings.Split(joinSpec, "=")
|
|
if len(keys) == 2 {
|
|
meta.foreignKey = strings.TrimSpace(keys[0])
|
|
meta.targetKey = strings.TrimSpace(keys[1])
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try to determine target table from field type
|
|
fieldType := field.Type
|
|
if fieldType.Kind() == reflect.Slice {
|
|
fieldType = fieldType.Elem()
|
|
}
|
|
if fieldType.Kind() == reflect.Ptr {
|
|
fieldType = fieldType.Elem()
|
|
}
|
|
|
|
if fieldType.Kind() == reflect.Struct {
|
|
// Try to get table name from the related model
|
|
relatedModel := reflect.New(fieldType).Interface()
|
|
if provider, ok := relatedModel.(common.TableNameProvider); ok {
|
|
meta.targetTable = provider.TableName()
|
|
}
|
|
}
|
|
|
|
// Set defaults if not found
|
|
if meta.foreignKey == "" {
|
|
meta.foreignKey = "id"
|
|
}
|
|
if meta.targetKey == "" {
|
|
meta.targetKey = "id"
|
|
}
|
|
|
|
return meta
|
|
}
|
|
|
|
// scanRows scans database rows into the destination using reflection
|
|
func scanRows(rows *sql.Rows, dest interface{}) error {
|
|
// Get column names
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get destination type
|
|
destValue := reflect.ValueOf(dest)
|
|
if destValue.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("dest must be a pointer")
|
|
}
|
|
|
|
destValue = destValue.Elem()
|
|
|
|
// Handle map slice: []map[string]interface{}
|
|
if destValue.Type() == reflect.TypeOf([]map[string]interface{}{}) {
|
|
return scanRowsToMapSlice(rows, columns, destValue)
|
|
}
|
|
|
|
// Handle struct slice: []MyStruct or []*MyStruct
|
|
if destValue.Kind() == reflect.Slice {
|
|
return scanRowsToStructSlice(rows, columns, destValue)
|
|
}
|
|
|
|
// Handle single struct: MyStruct or *MyStruct
|
|
if destValue.Kind() == reflect.Struct {
|
|
return scanRowsToSingleStruct(rows, columns, destValue)
|
|
}
|
|
|
|
return fmt.Errorf("unsupported destination type: %T", dest)
|
|
}
|
|
|
|
// scanRowsToMapSlice scans rows into []map[string]interface{}
|
|
func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
|
results := make([]map[string]interface{}, 0)
|
|
|
|
for rows.Next() {
|
|
// Create holders for values
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range values {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
err := rows.Scan(valuePtrs...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
row := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
row[col] = values[i]
|
|
}
|
|
results = append(results, row)
|
|
}
|
|
|
|
destValue.Set(reflect.ValueOf(results))
|
|
return rows.Err()
|
|
}
|
|
|
|
// scanRowsToStructSlice scans rows into a slice of structs
|
|
func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
|
elemType := destValue.Type().Elem()
|
|
isPtr := elemType.Kind() == reflect.Ptr
|
|
|
|
if isPtr {
|
|
elemType = elemType.Elem()
|
|
}
|
|
|
|
if elemType.Kind() != reflect.Struct {
|
|
return fmt.Errorf("slice element must be a struct, got %v", elemType.Kind())
|
|
}
|
|
|
|
// Build column-to-field mapping
|
|
fieldMap := buildFieldMap(elemType, columns)
|
|
|
|
for rows.Next() {
|
|
// Create a new instance of the struct
|
|
elemValue := reflect.New(elemType).Elem()
|
|
|
|
// Create scan targets
|
|
scanTargets := make([]interface{}, len(columns))
|
|
for i, col := range columns {
|
|
if fieldInfo, ok := fieldMap[col]; ok {
|
|
field := elemValue.FieldByIndex(fieldInfo.Index)
|
|
if field.CanSet() {
|
|
scanTargets[i] = field.Addr().Interface()
|
|
continue
|
|
}
|
|
}
|
|
// Use a dummy variable for unmapped columns
|
|
var dummy interface{}
|
|
scanTargets[i] = &dummy
|
|
}
|
|
|
|
err := rows.Scan(scanTargets...)
|
|
if err != nil {
|
|
return fmt.Errorf("scan failed: %w", err)
|
|
}
|
|
|
|
// Append to slice
|
|
if isPtr {
|
|
destValue.Set(reflect.Append(destValue, elemValue.Addr()))
|
|
} else {
|
|
destValue.Set(reflect.Append(destValue, elemValue))
|
|
}
|
|
}
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
// scanRowsToSingleStruct scans a single row into a struct
|
|
func scanRowsToSingleStruct(rows *sql.Rows, columns []string, destValue reflect.Value) error {
|
|
if !rows.Next() {
|
|
return sql.ErrNoRows
|
|
}
|
|
|
|
// Build column-to-field mapping
|
|
fieldMap := buildFieldMap(destValue.Type(), columns)
|
|
|
|
// Create scan targets
|
|
scanTargets := make([]interface{}, len(columns))
|
|
for i, col := range columns {
|
|
if fieldInfo, ok := fieldMap[col]; ok {
|
|
field := destValue.FieldByIndex(fieldInfo.Index)
|
|
if field.CanSet() {
|
|
scanTargets[i] = field.Addr().Interface()
|
|
continue
|
|
}
|
|
}
|
|
// Use a dummy variable for unmapped columns
|
|
var dummy interface{}
|
|
scanTargets[i] = &dummy
|
|
}
|
|
|
|
err := rows.Scan(scanTargets...)
|
|
if err != nil {
|
|
return fmt.Errorf("scan failed: %w", err)
|
|
}
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
// fieldInfo holds information about a struct field
|
|
type fieldInfo struct {
|
|
Index []int
|
|
Name string
|
|
}
|
|
|
|
// buildFieldMap creates a mapping from column names to struct fields
|
|
func buildFieldMap(structType reflect.Type, _ []string) map[string]fieldInfo {
|
|
fieldMap := make(map[string]fieldInfo)
|
|
|
|
// Iterate through struct fields
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
field := structType.Field(i)
|
|
|
|
// Skip unexported fields
|
|
if !field.IsExported() {
|
|
continue
|
|
}
|
|
|
|
// Get column name from struct tag or field name
|
|
colName := field.Name
|
|
|
|
// Check for bun tag
|
|
if bunTag := field.Tag.Get("bun"); bunTag != "" {
|
|
parts := strings.Split(bunTag, ",")
|
|
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
|
colName = parts[0]
|
|
}
|
|
}
|
|
|
|
// Check for db tag (common convention)
|
|
if dbTag := field.Tag.Get("db"); dbTag != "" && dbTag != "-" {
|
|
colName = dbTag
|
|
}
|
|
|
|
// Convert to lowercase for case-insensitive matching
|
|
colNameLower := strings.ToLower(colName)
|
|
|
|
fieldMap[colNameLower] = fieldInfo{
|
|
Index: field.Index,
|
|
Name: field.Name,
|
|
}
|
|
|
|
// Also map by exact field name
|
|
fieldMap[strings.ToLower(field.Name)] = fieldInfo{
|
|
Index: field.Index,
|
|
Name: field.Name,
|
|
}
|
|
}
|
|
|
|
return fieldMap
|
|
}
|