mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-06-09 23:33:45 +00:00
fix(database): add Scan method to insert query interfaces
* Implement Scan method for BunInsertQuery, GormInsertQuery, and PgSQLInsertQuery * Update mock implementations to support Scan method * Introduce GetForeignKeyColumn utility for foreign key resolution * Add tests for GetForeignKeyColumn functionality
This commit is contained in:
@@ -1451,6 +1451,18 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) prepareValues() {
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
b.query = b.query.Model(&b.values)
|
||||
} else {
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Value(k, "?", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
@@ -1458,23 +1470,25 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
// Bun can insert map[string]interface{} directly
|
||||
b.query = b.query.Model(&b.values)
|
||||
} else {
|
||||
// If model was set, use Value() to add individual values
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Value(k, "?", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
b.prepareValues()
|
||||
result, err := b.query.Exec(ctx)
|
||||
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunInsertQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
b.prepareValues()
|
||||
err = b.query.Scan(ctx, dest)
|
||||
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// BunUpdateQuery implements UpdateQuery for Bun
|
||||
type BunUpdateQuery struct {
|
||||
query *bun.UpdateQuery
|
||||
|
||||
@@ -3,11 +3,13 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -676,15 +678,16 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
|
||||
// GormInsertQuery implements InsertQuery for GORM
|
||||
type GormInsertQuery struct {
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
db *gorm.DB
|
||||
reconnect func(...*gorm.DB) error
|
||||
model interface{}
|
||||
values map[string]interface{}
|
||||
schema string
|
||||
tableName string
|
||||
entity string
|
||||
driverName string
|
||||
metricsEnabled bool
|
||||
returningColumns []string
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
@@ -718,7 +721,7 @@ func (g *GormInsertQuery) OnConflict(action string) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
// GORM doesn't have explicit RETURNING, but updates the model
|
||||
g.returningColumns = columns
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -749,6 +752,76 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormInsertQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
startedAt := time.Now()
|
||||
|
||||
var returningCols []clause.Column
|
||||
for _, col := range g.returningColumns {
|
||||
returningCols = append(returningCols, clause.Column{Name: col})
|
||||
}
|
||||
|
||||
db := g.db.WithContext(ctx)
|
||||
if len(returningCols) > 0 {
|
||||
db = db.Clauses(clause.Returning{Columns: returningCols})
|
||||
}
|
||||
|
||||
var result *gorm.DB
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = db.Create(g.model)
|
||||
case g.values != nil:
|
||||
result = db.Create(g.values)
|
||||
default:
|
||||
result = db.Create(map[string]interface{}{})
|
||||
}
|
||||
|
||||
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||
result = db.Create(g.model)
|
||||
}
|
||||
}
|
||||
|
||||
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// Extract the returning column value from the model or values map
|
||||
if len(g.returningColumns) == 1 {
|
||||
col := g.returningColumns[0]
|
||||
if g.model != nil {
|
||||
val := reflect.ValueOf(g.model)
|
||||
if val.Kind() == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
if val.Kind() == reflect.Struct {
|
||||
for i := 0; i < val.NumField(); i++ {
|
||||
f := val.Type().Field(i)
|
||||
dbTag := strings.Split(f.Tag.Get("bun"), ",")[0]
|
||||
jsonTag := strings.Split(f.Tag.Get("json"), ",")[0]
|
||||
if strings.EqualFold(f.Name, col) || dbTag == col || jsonTag == col {
|
||||
reflect.ValueOf(dest).Elem().Set(val.Field(i))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if g.values != nil {
|
||||
if v, ok := g.values[col]; ok {
|
||||
reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(v))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GormUpdateQuery implements UpdateQuery for GORM
|
||||
type GormUpdateQuery struct {
|
||||
db *gorm.DB
|
||||
|
||||
@@ -708,6 +708,51 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
|
||||
return &PgSQLResult{result: result}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
startedAt := time.Now()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("PgSQLInsertQuery.Scan", r)
|
||||
}
|
||||
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||
}()
|
||||
|
||||
if len(p.values) == 0 {
|
||||
return fmt.Errorf("no values to insert")
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(p.values))
|
||||
placeholders := make([]string, 0, len(p.values))
|
||||
args := make([]interface{}, 0, len(p.values))
|
||||
i := 1
|
||||
for _, col := range p.valueOrder {
|
||||
columns = append(columns, col)
|
||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
||||
args = append(args, p.values[col])
|
||||
i++
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||
p.tableName,
|
||||
strings.Join(columns, ", "),
|
||||
strings.Join(placeholders, ", "))
|
||||
|
||||
if len(p.returning) > 0 {
|
||||
query += " RETURNING " + strings.Join(p.returning, ", ")
|
||||
}
|
||||
|
||||
logger.Debug("PgSQL INSERT (Scan): %s [args: %v]", query, args)
|
||||
|
||||
var row *sql.Row
|
||||
if p.tx != nil {
|
||||
row = p.tx.QueryRowContext(ctx, query, args...)
|
||||
} else {
|
||||
row = p.db.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
return row.Scan(dest)
|
||||
}
|
||||
|
||||
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
||||
type PgSQLUpdateQuery struct {
|
||||
db *sql.DB
|
||||
|
||||
Reference in New Issue
Block a user