diff --git a/.vscode/tasks.json b/.vscode/tasks.json index d6b77dd..1d23c58 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -230,7 +230,17 @@ "cwd": "${workspaceFolder}" }, "problemMatcher": [], - "group": "test" + "group": "build" + }, + { + "type": "shell", + "label": "go: lint workspace (fix)", + "command": "golangci-lint run --timeout=5m --fix", + "options": { + "cwd": "${workspaceFolder}" + }, + "problemMatcher": [], + "group": "build" }, { "type": "shell", @@ -275,4 +285,4 @@ "command": "sh ${workspaceFolder}/make_release.sh" } ] -} +} \ No newline at end of file diff --git a/pkg/common/adapters/database/pgsql.go b/pkg/common/adapters/database/pgsql.go new file mode 100644 index 0000000..f71bf26 --- /dev/null +++ b/pkg/common/adapters/database/pgsql.go @@ -0,0 +1,1355 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" + + "github.com/bitechdev/ResolveSpec/pkg/common" + "github.com/bitechdev/ResolveSpec/pkg/logger" + "github.com/bitechdev/ResolveSpec/pkg/modelregistry" + "github.com/bitechdev/ResolveSpec/pkg/reflection" +) + +// PgSQLAdapter adapts standard database/sql to work with our Database interface +// This provides a lightweight PostgreSQL adapter without ORM overhead +type PgSQLAdapter struct { + db *sql.DB +} + +// NewPgSQLAdapter creates a new PostgreSQL adapter +func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter { + return &PgSQLAdapter{db: db} +} + +// EnableQueryDebug enables query debugging for development +func (p *PgSQLAdapter) EnableQueryDebug() { + logger.Info("PgSQL query debug mode - logging enabled via logger") +} + +func (p *PgSQLAdapter) NewSelect() common.SelectQuery { + return &PgSQLSelectQuery{ + db: p.db, + columns: []string{"*"}, + args: make([]interface{}, 0), + } +} + +func (p *PgSQLAdapter) NewInsert() common.InsertQuery { + return &PgSQLInsertQuery{ + db: p.db, + values: make(map[string]interface{}), + } +} + +func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery { + return &PgSQLUpdateQuery{ + db: p.db, + sets: make(map[string]interface{}), + args: make([]interface{}, 0), + whereClauses: make([]string, 0), + } +} + +func (p *PgSQLAdapter) NewDelete() common.DeleteQuery { + return &PgSQLDeleteQuery{ + db: p.db, + args: make([]interface{}, 0), + whereClauses: make([]string, 0), + } +} + +func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLAdapter.Exec", r) + } + }() + logger.Debug("PgSQL Exec: %s [args: %v]", query, args) + result, err := p.db.ExecContext(ctx, query, args...) + if err != nil { + logger.Error("PgSQL Exec failed: %v", err) + return nil, err + } + return &PgSQLResult{result: result}, nil +} + +func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLAdapter.Query", r) + } + }() + logger.Debug("PgSQL Query: %s [args: %v]", query, args) + rows, err := p.db.QueryContext(ctx, query, args...) + if err != nil { + logger.Error("PgSQL Query failed: %v", err) + return err + } + defer rows.Close() + + return scanRows(rows, dest) +} + +func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) { + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + return &PgSQLTxAdapter{tx: tx}, nil +} + +func (p *PgSQLAdapter) CommitTx(ctx context.Context) error { + return fmt.Errorf("CommitTx should be called on transaction adapter") +} + +func (p *PgSQLAdapter) RollbackTx(ctx context.Context) error { + return fmt.Errorf("RollbackTx should be called on transaction adapter") +} + +func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLAdapter.RunInTransaction", r) + } + }() + + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return err + } + + adapter := &PgSQLTxAdapter{tx: tx} + + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) + } else if err != nil { + _ = tx.Rollback() + } else { + err = tx.Commit() + } + }() + + return fn(adapter) +} + +// preloadConfig represents a relationship to be preloaded +type preloadConfig struct { + relation string + conditions []interface{} + applyFuncs []func(common.SelectQuery) common.SelectQuery + useJoin bool +} + +// relationMetadata contains information about a relationship +type relationMetadata struct { + fieldName string + relationType reflection.RelationType + foreignKey string + targetTable string + targetKey string +} + +// PgSQLSelectQuery implements SelectQuery for PostgreSQL +type PgSQLSelectQuery struct { + db *sql.DB + tx *sql.Tx + model interface{} + tableName string + tableAlias string + columns []string + columnExprs []string + whereClauses []string + orClauses []string + joins []string + orderBy []string + groupBy []string + havingClauses []string + limit int + offset int + args []interface{} + paramCounter int + preloads []preloadConfig +} + +func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery { + p.model = model + if provider, ok := model.(common.TableNameProvider); ok { + p.tableName = provider.TableName() + } + if provider, ok := model.(common.TableAliasProvider); ok { + p.tableAlias = provider.TableAlias() + } + return p +} + +func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery { + p.tableName = table + return p +} + +func (p *PgSQLSelectQuery) Column(columns ...string) common.SelectQuery { + if len(p.columns) == 1 && p.columns[0] == "*" { + p.columns = make([]string, 0) + } + p.columns = append(p.columns, columns...) + return p +} + +func (p *PgSQLSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { + p.columnExprs = append(p.columnExprs, query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { + // Replace ? placeholders with $1, $2, etc. + query = p.replacePlaceholders(query, len(args)) + p.whereClauses = append(p.whereClauses, query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { + query = p.replacePlaceholders(query, len(args)) + p.orClauses = append(p.orClauses, query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { + query = p.replacePlaceholders(query, len(args)) + p.joins = append(p.joins, "JOIN "+query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { + query = p.replacePlaceholders(query, len(args)) + p.joins = append(p.joins, "LEFT JOIN "+query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery { + p.preloads = append(p.preloads, preloadConfig{ + relation: relation, + conditions: conditions, + useJoin: false, // Always use subquery for simple Preload + }) + return p +} + +func (p *PgSQLSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + // Auto-detect relationship type and choose optimal loading strategy + var useJoin bool + if p.model != nil { + relType := reflection.GetRelationType(p.model, relation) + useJoin = relType.ShouldUseJoin() + logger.Debug("PreloadRelation '%s' detected as: %s (useJoin: %v)", relation, relType, useJoin) + } + + p.preloads = append(p.preloads, preloadConfig{ + relation: relation, + applyFuncs: apply, + useJoin: useJoin, + }) + return p +} + +func (p *PgSQLSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { + // Force JOIN loading + logger.Debug("JoinRelation '%s' - forcing JOIN strategy", relation) + p.preloads = append(p.preloads, preloadConfig{ + relation: relation, + applyFuncs: apply, + useJoin: true, // Force JOIN + }) + return p +} + +func (p *PgSQLSelectQuery) Order(order string) common.SelectQuery { + p.orderBy = append(p.orderBy, order) + return p +} + +func (p *PgSQLSelectQuery) Limit(n int) common.SelectQuery { + p.limit = n + return p +} + +func (p *PgSQLSelectQuery) Offset(n int) common.SelectQuery { + p.offset = n + return p +} + +func (p *PgSQLSelectQuery) Group(group string) common.SelectQuery { + p.groupBy = append(p.groupBy, group) + return p +} + +func (p *PgSQLSelectQuery) Having(having string, args ...interface{}) common.SelectQuery { + having = p.replacePlaceholders(having, len(args)) + p.havingClauses = append(p.havingClauses, having) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLSelectQuery) buildSQL() string { + var sb strings.Builder + + // SELECT clause + sb.WriteString("SELECT ") + if len(p.columns) > 0 || len(p.columnExprs) > 0 { + allCols := make([]string, 0) + allCols = append(allCols, p.columns...) + allCols = append(allCols, p.columnExprs...) + sb.WriteString(strings.Join(allCols, ", ")) + } else { + sb.WriteString("*") + } + + // FROM clause + if p.tableName != "" { + sb.WriteString(" FROM ") + sb.WriteString(p.tableName) + if p.tableAlias != "" { + sb.WriteString(" AS ") + sb.WriteString(p.tableAlias) + } + } + + // JOIN clauses + if len(p.joins) > 0 { + sb.WriteString(" ") + sb.WriteString(strings.Join(p.joins, " ")) + } + + // WHERE clause + if len(p.whereClauses) > 0 || len(p.orClauses) > 0 { + sb.WriteString(" WHERE ") + conditions := make([]string, 0) + + if len(p.whereClauses) > 0 { + conditions = append(conditions, "("+strings.Join(p.whereClauses, " AND ")+")") + } + if len(p.orClauses) > 0 { + conditions = append(conditions, "("+strings.Join(p.orClauses, " OR ")+")") + } + + sb.WriteString(strings.Join(conditions, " AND ")) + } + + // GROUP BY clause + if len(p.groupBy) > 0 { + sb.WriteString(" GROUP BY ") + sb.WriteString(strings.Join(p.groupBy, ", ")) + } + + // HAVING clause + if len(p.havingClauses) > 0 { + sb.WriteString(" HAVING ") + sb.WriteString(strings.Join(p.havingClauses, " AND ")) + } + + // ORDER BY clause + if len(p.orderBy) > 0 { + sb.WriteString(" ORDER BY ") + sb.WriteString(strings.Join(p.orderBy, ", ")) + } + + // LIMIT clause + if p.limit > 0 { + sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit)) + } + + // OFFSET clause + if p.offset > 0 { + sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset)) + } + + return sb.String() +} + +func (p *PgSQLSelectQuery) replacePlaceholders(query string, argCount int) string { + result := query + for i := 0; i < argCount; i++ { + p.paramCounter++ + placeholder := fmt.Sprintf("$%d", p.paramCounter) + result = strings.Replace(result, "?", placeholder, 1) + } + return result +} + +func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLSelectQuery.Scan", r) + } + }() + + // Apply preloads that use JOINs + p.applyJoinPreloads() + + query := p.buildSQL() + logger.Debug("PgSQL SELECT: %s [args: %v]", query, p.args) + + var rows *sql.Rows + if p.tx != nil { + rows, err = p.tx.QueryContext(ctx, query, p.args...) + } else { + rows, err = p.db.QueryContext(ctx, query, p.args...) + } + + if err != nil { + logger.Error("PgSQL SELECT failed: %v", err) + return err + } + defer rows.Close() + + err = scanRows(rows, dest) + if err != nil { + return err + } + + // Apply preloads that use separate queries + return p.applySubqueryPreloads(ctx, dest) +} + +func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error { + if p.model == nil { + return fmt.Errorf("ScanModel requires Model() to be set before scanning") + } + return p.Scan(ctx, p.model) +} + +func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLSelectQuery.Count", r) + count = 0 + } + }() + + // Build a COUNT query + var sb strings.Builder + sb.WriteString("SELECT COUNT(*) FROM ") + sb.WriteString(p.tableName) + + if len(p.joins) > 0 { + sb.WriteString(" ") + sb.WriteString(strings.Join(p.joins, " ")) + } + + if len(p.whereClauses) > 0 || len(p.orClauses) > 0 { + sb.WriteString(" WHERE ") + conditions := make([]string, 0) + + if len(p.whereClauses) > 0 { + conditions = append(conditions, "("+strings.Join(p.whereClauses, " AND ")+")") + } + if len(p.orClauses) > 0 { + conditions = append(conditions, "("+strings.Join(p.orClauses, " OR ")+")") + } + + sb.WriteString(strings.Join(conditions, " AND ")) + } + + query := sb.String() + logger.Debug("PgSQL COUNT: %s [args: %v]", query, p.args) + + var row *sql.Row + if p.tx != nil { + row = p.tx.QueryRowContext(ctx, query, p.args...) + } else { + row = p.db.QueryRowContext(ctx, query, p.args...) + } + + err = row.Scan(&count) + if err != nil { + logger.Error("PgSQL COUNT failed: %v", err) + } + return count, err +} + +func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLSelectQuery.Exists", r) + exists = false + } + }() + + count, err := p.Count(ctx) + return count > 0, err +} + +// PgSQLInsertQuery implements InsertQuery for PostgreSQL +type PgSQLInsertQuery struct { + db *sql.DB + tx *sql.Tx + tableName string + values map[string]interface{} + returning []string +} + +func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery { + if provider, ok := model.(common.TableNameProvider); ok { + p.tableName = provider.TableName() + } + // Extract values from model using reflection + // This is a simplified implementation + return p +} + +func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery { + p.tableName = table + return p +} + +func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery { + p.values[column] = value + return p +} + +func (p *PgSQLInsertQuery) OnConflict(action string) common.InsertQuery { + logger.Warn("OnConflict not yet implemented in PgSQL adapter") + return p +} + +func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery { + p.returning = columns + return p +} + +func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLInsertQuery.Exec", r) + } + }() + + if len(p.values) == 0 { + return nil, fmt.Errorf("no values to insert") + } + + columns := make([]string, 0, len(p.values)) + placeholders := make([]string, 0, len(p.values)) + args := make([]interface{}, 0, len(p.values)) + + i := 1 + for col, val := range p.values { + columns = append(columns, col) + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + args = append(args, val) + i++ + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + p.tableName, + strings.Join(columns, ", "), + strings.Join(placeholders, ", ")) + + if len(p.returning) > 0 { + query += " RETURNING " + strings.Join(p.returning, ", ") + } + + logger.Debug("PgSQL INSERT: %s [args: %v]", query, args) + + var result sql.Result + if p.tx != nil { + result, err = p.tx.ExecContext(ctx, query, args...) + } else { + result, err = p.db.ExecContext(ctx, query, args...) + } + + if err != nil { + logger.Error("PgSQL INSERT failed: %v", err) + return nil, err + } + + return &PgSQLResult{result: result}, nil +} + +// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL +type PgSQLUpdateQuery struct { + db *sql.DB + tx *sql.Tx + tableName string + model interface{} + sets map[string]interface{} + whereClauses []string + args []interface{} + paramCounter int + returning []string +} + +func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery { + p.model = model + if provider, ok := model.(common.TableNameProvider); ok { + p.tableName = provider.TableName() + } + return p +} + +func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery { + p.tableName = table + if p.model == nil { + model, err := modelregistry.GetModelByName(table) + if err == nil { + p.model = model + } + } + return p +} + +func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { + if p.model != nil && !reflection.IsColumnWritable(p.model, column) { + return p + } + p.sets[column] = value + return p +} + +func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { + pkName := "" + if p.model != nil { + pkName = reflection.GetPrimaryKeyName(p.model) + } + + for column, value := range values { + if pkName != "" && column == pkName { + continue + } + if p.model != nil && !reflection.IsColumnWritable(p.model, column) { + continue + } + p.sets[column] = value + } + return p +} + +func (p *PgSQLUpdateQuery) Where(query string, args ...interface{}) common.UpdateQuery { + query = p.replacePlaceholders(query, len(args)) + p.whereClauses = append(p.whereClauses, query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLUpdateQuery) Returning(columns ...string) common.UpdateQuery { + p.returning = columns + return p +} + +func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) string { + result := query + for i := 0; i < argCount; i++ { + p.paramCounter++ + placeholder := fmt.Sprintf("$%d", p.paramCounter) + result = strings.Replace(result, "?", placeholder, 1) + } + return result +} + +func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r) + } + }() + + if len(p.sets) == 0 { + return nil, fmt.Errorf("no values to update") + } + + setClauses := make([]string, 0, len(p.sets)) + setArgs := make([]interface{}, 0, len(p.sets)) + + // SET parameters start at $1 + i := 1 + for col, val := range p.sets { + setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i)) + setArgs = append(setArgs, val) + i++ + } + + query := fmt.Sprintf("UPDATE %s SET %s", + p.tableName, + strings.Join(setClauses, ", ")) + + // Update WHERE clause parameter numbers to continue after SET parameters + if len(p.whereClauses) > 0 { + updatedWhereClauses := make([]string, 0, len(p.whereClauses)) + for _, whereClause := range p.whereClauses { + // Find and replace parameter placeholders + updatedClause := whereClause + paramNum := i + // Count how many parameters are in this WHERE clause + placeholderCount := strings.Count(whereClause, "$") + for j := 0; j < placeholderCount; j++ { + oldParam := fmt.Sprintf("$%d", j+1) + newParam := fmt.Sprintf("$%d", paramNum) + updatedClause = strings.Replace(updatedClause, oldParam, newParam, 1) + paramNum++ + } + updatedWhereClauses = append(updatedWhereClauses, updatedClause) + i = paramNum + } + p.whereClauses = updatedWhereClauses + } + + // All arguments: SET values first, then WHERE values + // Create a new slice to avoid modifying setArgs + allArgs := make([]interface{}, len(setArgs)+len(p.args)) + copy(allArgs, setArgs) + copy(allArgs[len(setArgs):], p.args) + + if len(p.whereClauses) > 0 { + query += " WHERE " + strings.Join(p.whereClauses, " AND ") + } + + if len(p.returning) > 0 { + query += " RETURNING " + strings.Join(p.returning, ", ") + } + + logger.Debug("PgSQL UPDATE: %s [args: %v]", query, allArgs) + + var result sql.Result + if p.tx != nil { + result, err = p.tx.ExecContext(ctx, query, allArgs...) + } else { + result, err = p.db.ExecContext(ctx, query, allArgs...) + } + + if err != nil { + logger.Error("PgSQL UPDATE failed: %v", err) + return nil, err + } + + return &PgSQLResult{result: result}, nil +} + +// PgSQLDeleteQuery implements DeleteQuery for PostgreSQL +type PgSQLDeleteQuery struct { + db *sql.DB + tx *sql.Tx + tableName string + whereClauses []string + args []interface{} + paramCounter int +} + +func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery { + if provider, ok := model.(common.TableNameProvider); ok { + p.tableName = provider.TableName() + } + return p +} + +func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery { + p.tableName = table + return p +} + +func (p *PgSQLDeleteQuery) Where(query string, args ...interface{}) common.DeleteQuery { + query = p.replacePlaceholders(query, len(args)) + p.whereClauses = append(p.whereClauses, query) + p.args = append(p.args, args...) + return p +} + +func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) string { + result := query + for i := 0; i < argCount; i++ { + p.paramCounter++ + placeholder := fmt.Sprintf("$%d", p.paramCounter) + result = strings.Replace(result, "?", placeholder, 1) + } + return result +} + +func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { + defer func() { + if r := recover(); r != nil { + err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r) + } + }() + + query := fmt.Sprintf("DELETE FROM %s", p.tableName) + + if len(p.whereClauses) > 0 { + query += " WHERE " + strings.Join(p.whereClauses, " AND ") + } + + logger.Debug("PgSQL DELETE: %s [args: %v]", query, p.args) + + var result sql.Result + if p.tx != nil { + result, err = p.tx.ExecContext(ctx, query, p.args...) + } else { + result, err = p.db.ExecContext(ctx, query, p.args...) + } + + if err != nil { + logger.Error("PgSQL DELETE failed: %v", err) + return nil, err + } + + return &PgSQLResult{result: result}, nil +} + +// PgSQLResult implements Result for PostgreSQL +type PgSQLResult struct { + result sql.Result +} + +func (p *PgSQLResult) RowsAffected() int64 { + if p.result == nil { + return 0 + } + rows, _ := p.result.RowsAffected() + return rows +} + +func (p *PgSQLResult) LastInsertId() (int64, error) { + if p.result == nil { + return 0, nil + } + return p.result.LastInsertId() +} + +// PgSQLTxAdapter wraps a PostgreSQL transaction +type PgSQLTxAdapter struct { + tx *sql.Tx +} + +func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery { + return &PgSQLSelectQuery{ + tx: p.tx, + columns: []string{"*"}, + args: make([]interface{}, 0), + } +} + +func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery { + return &PgSQLInsertQuery{ + tx: p.tx, + values: make(map[string]interface{}), + } +} + +func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery { + return &PgSQLUpdateQuery{ + tx: p.tx, + sets: make(map[string]interface{}), + args: make([]interface{}, 0), + whereClauses: make([]string, 0), + } +} + +func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery { + return &PgSQLDeleteQuery{ + tx: p.tx, + args: make([]interface{}, 0), + whereClauses: make([]string, 0), + } +} + +func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { + logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args) + result, err := p.tx.ExecContext(ctx, query, args...) + if err != nil { + logger.Error("PgSQL Tx Exec failed: %v", err) + return nil, err + } + return &PgSQLResult{result: result}, nil +} + +func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args) + rows, err := p.tx.QueryContext(ctx, query, args...) + if err != nil { + logger.Error("PgSQL Tx Query failed: %v", err) + return err + } + defer rows.Close() + + return scanRows(rows, dest) +} + +func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) { + return nil, fmt.Errorf("nested transactions not supported") +} + +func (p *PgSQLTxAdapter) CommitTx(ctx context.Context) error { + return p.tx.Commit() +} + +func (p *PgSQLTxAdapter) RollbackTx(ctx context.Context) error { + return p.tx.Rollback() +} + +func (p *PgSQLTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { + return fn(p) +} + +// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy +func (p *PgSQLSelectQuery) applyJoinPreloads() { + for _, preload := range p.preloads { + if !preload.useJoin { + continue + } + + // Build JOIN based on relationship metadata + meta := p.getRelationMetadata(preload.relation) + if meta == nil { + logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation) + continue + } + + // Build the JOIN clause + relationAlias := strings.ToLower(preload.relation) + joinClause := fmt.Sprintf("%s AS %s ON %s.%s = %s.%s", + meta.targetTable, + relationAlias, + p.tableAlias, + meta.foreignKey, + relationAlias, + meta.targetKey, + ) + + logger.Debug("Adding LEFT JOIN for relation '%s': %s", preload.relation, joinClause) + p.joins = append(p.joins, "LEFT JOIN "+joinClause) + + // Apply any custom conditions through applyFuncs + // Note: These would need to be integrated into the WHERE clause + // For simplicity, we're logging a warning if custom conditions are present + if len(preload.applyFuncs) > 0 { + logger.Warn("Custom conditions in JoinRelation not yet fully implemented") + } + } +} + +// applySubqueryPreloads executes separate queries for has-many and many-to-many relationships +func (p *PgSQLSelectQuery) applySubqueryPreloads(ctx context.Context, dest interface{}) error { + // Get all preloads that don't use JOIN + subqueryPreloads := make([]preloadConfig, 0) + for _, preload := range p.preloads { + if !preload.useJoin { + subqueryPreloads = append(subqueryPreloads, preload) + } + } + + if len(subqueryPreloads) == 0 { + return nil + } + + // Use reflection to process the destination + destValue := reflect.ValueOf(dest) + if destValue.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + destValue = destValue.Elem() + + // Handle slice of structs + if destValue.Kind() == reflect.Slice { + for i := 0; i < destValue.Len(); i++ { + elem := destValue.Index(i) + if err := p.loadPreloadsForRecord(ctx, elem, subqueryPreloads); err != nil { + logger.Warn("Failed to load preloads for record %d: %v", i, err) + } + } + return nil + } + + // Handle single struct + if destValue.Kind() == reflect.Struct { + return p.loadPreloadsForRecord(ctx, destValue, subqueryPreloads) + } + + return nil +} + +// loadPreloadsForRecord loads all preload relationships for a single record +func (p *PgSQLSelectQuery) loadPreloadsForRecord(ctx context.Context, record reflect.Value, preloads []preloadConfig) error { + if record.Kind() == reflect.Ptr { + if record.IsNil() { + return nil + } + record = record.Elem() + } + + for _, preload := range preloads { + field := record.FieldByName(preload.relation) + if !field.IsValid() || !field.CanSet() { + logger.Warn("Field '%s' not found or cannot be set", preload.relation) + continue + } + + meta := p.getRelationMetadataFromField(record.Type(), preload.relation) + if meta == nil { + logger.Warn("Cannot determine relationship metadata for '%s'", preload.relation) + continue + } + + // Get the foreign key value from the parent record + fkField := record.FieldByName(meta.foreignKey) + if !fkField.IsValid() { + logger.Warn("Foreign key field '%s' not found", meta.foreignKey) + continue + } + + fkValue := fkField.Interface() + + // Build and execute the preload query + err := p.executePreloadQuery(ctx, field, meta, fkValue, preload) + if err != nil { + logger.Warn("Failed to execute preload query for '%s': %v", preload.relation, err) + } + } + + return nil +} + +// executePreloadQuery executes a query to load a relationship +func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflect.Value, meta *relationMetadata, fkValue interface{}, preload preloadConfig) error { + // Create a new select query for the related table + var db common.Database + if p.tx != nil { + db = &PgSQLTxAdapter{tx: p.tx} + } else { + db = &PgSQLAdapter{db: p.db} + } + + query := db.NewSelect(). + Table(meta.targetTable). + Where(fmt.Sprintf("%s = ?", meta.targetKey), fkValue) + + // Apply custom functions + for _, applyFunc := range preload.applyFuncs { + if applyFunc != nil { + query = applyFunc(query) + } + } + + // Determine if this is a slice (has-many) or single struct (belongs-to/has-one) + if field.Kind() == reflect.Slice { + // Create a new slice to hold results + sliceType := field.Type() + results := reflect.New(sliceType).Elem() + + // Execute query + err := query.Scan(ctx, results.Addr().Interface()) + if err != nil { + return err + } + + // Set the field + field.Set(results) + } else { + // Single struct - create a pointer if needed + var target reflect.Value + if field.Kind() == reflect.Ptr { + target = reflect.New(field.Type().Elem()) + } else { + target = reflect.New(field.Type()) + } + + // Execute query with LIMIT 1 + err := query.Limit(1).Scan(ctx, target.Interface()) + if err != nil && err != sql.ErrNoRows { + return err + } + + // Set the field + if field.Kind() == reflect.Ptr { + field.Set(target) + } else { + field.Set(target.Elem()) + } + } + + return nil +} + +// getRelationMetadata extracts relationship metadata from the model +func (p *PgSQLSelectQuery) getRelationMetadata(fieldName string) *relationMetadata { + if p.model == nil { + return nil + } + + modelType := reflect.TypeOf(p.model) + if modelType.Kind() == reflect.Ptr { + modelType = modelType.Elem() + } + + return p.getRelationMetadataFromField(modelType, fieldName) +} + +// getRelationMetadataFromField extracts relationship metadata from a type +func (p *PgSQLSelectQuery) getRelationMetadataFromField(modelType reflect.Type, fieldName string) *relationMetadata { + if modelType.Kind() != reflect.Struct { + return nil + } + + field, found := modelType.FieldByName(fieldName) + if !found { + return nil + } + + meta := &relationMetadata{ + fieldName: fieldName, + relationType: reflection.GetRelationType(reflect.New(modelType).Interface(), fieldName), + } + + // Parse struct tags to get foreign key and target table + bunTag := field.Tag.Get("bun") + if bunTag != "" { + // Parse bun tags: rel:has-many,join:user_id=id + parts := strings.Split(bunTag, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "join:") { + // Parse join condition: join:user_id=id + joinSpec := strings.TrimPrefix(part, "join:") + if strings.Contains(joinSpec, "=") { + keys := strings.Split(joinSpec, "=") + if len(keys) == 2 { + meta.foreignKey = strings.TrimSpace(keys[0]) + meta.targetKey = strings.TrimSpace(keys[1]) + } + } + } + } + } + + // Try to determine target table from field type + fieldType := field.Type + if fieldType.Kind() == reflect.Slice { + fieldType = fieldType.Elem() + } + if fieldType.Kind() == reflect.Ptr { + fieldType = fieldType.Elem() + } + + if fieldType.Kind() == reflect.Struct { + // Try to get table name from the related model + relatedModel := reflect.New(fieldType).Interface() + if provider, ok := relatedModel.(common.TableNameProvider); ok { + meta.targetTable = provider.TableName() + } + } + + // Set defaults if not found + if meta.foreignKey == "" { + meta.foreignKey = "id" + } + if meta.targetKey == "" { + meta.targetKey = "id" + } + + return meta +} + +// scanRows scans database rows into the destination using reflection +func scanRows(rows *sql.Rows, dest interface{}) error { + // Get column names + columns, err := rows.Columns() + if err != nil { + return err + } + + // Get destination type + destValue := reflect.ValueOf(dest) + if destValue.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + destValue = destValue.Elem() + + // Handle map slice: []map[string]interface{} + if destValue.Type() == reflect.TypeOf([]map[string]interface{}{}) { + return scanRowsToMapSlice(rows, columns, destValue) + } + + // Handle struct slice: []MyStruct or []*MyStruct + if destValue.Kind() == reflect.Slice { + return scanRowsToStructSlice(rows, columns, destValue) + } + + // Handle single struct: MyStruct or *MyStruct + if destValue.Kind() == reflect.Struct { + return scanRowsToSingleStruct(rows, columns, destValue) + } + + return fmt.Errorf("unsupported destination type: %T", dest) +} + +// scanRowsToMapSlice scans rows into []map[string]interface{} +func scanRowsToMapSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error { + results := make([]map[string]interface{}, 0) + + for rows.Next() { + // Create holders for values + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + + err := rows.Scan(valuePtrs...) + if err != nil { + return err + } + + row := make(map[string]interface{}) + for i, col := range columns { + row[col] = values[i] + } + results = append(results, row) + } + + destValue.Set(reflect.ValueOf(results)) + return rows.Err() +} + +// scanRowsToStructSlice scans rows into a slice of structs +func scanRowsToStructSlice(rows *sql.Rows, columns []string, destValue reflect.Value) error { + elemType := destValue.Type().Elem() + isPtr := elemType.Kind() == reflect.Ptr + + if isPtr { + elemType = elemType.Elem() + } + + if elemType.Kind() != reflect.Struct { + return fmt.Errorf("slice element must be a struct, got %v", elemType.Kind()) + } + + // Build column-to-field mapping + fieldMap := buildFieldMap(elemType, columns) + + for rows.Next() { + // Create a new instance of the struct + elemValue := reflect.New(elemType).Elem() + + // Create scan targets + scanTargets := make([]interface{}, len(columns)) + for i, col := range columns { + if fieldInfo, ok := fieldMap[col]; ok { + field := elemValue.FieldByIndex(fieldInfo.Index) + if field.CanSet() { + scanTargets[i] = field.Addr().Interface() + continue + } + } + // Use a dummy variable for unmapped columns + var dummy interface{} + scanTargets[i] = &dummy + } + + err := rows.Scan(scanTargets...) + if err != nil { + return fmt.Errorf("scan failed: %w", err) + } + + // Append to slice + if isPtr { + destValue.Set(reflect.Append(destValue, elemValue.Addr())) + } else { + destValue.Set(reflect.Append(destValue, elemValue)) + } + } + + return rows.Err() +} + +// scanRowsToSingleStruct scans a single row into a struct +func scanRowsToSingleStruct(rows *sql.Rows, columns []string, destValue reflect.Value) error { + if !rows.Next() { + return sql.ErrNoRows + } + + // Build column-to-field mapping + fieldMap := buildFieldMap(destValue.Type(), columns) + + // Create scan targets + scanTargets := make([]interface{}, len(columns)) + for i, col := range columns { + if fieldInfo, ok := fieldMap[col]; ok { + field := destValue.FieldByIndex(fieldInfo.Index) + if field.CanSet() { + scanTargets[i] = field.Addr().Interface() + continue + } + } + // Use a dummy variable for unmapped columns + var dummy interface{} + scanTargets[i] = &dummy + } + + err := rows.Scan(scanTargets...) + if err != nil { + return fmt.Errorf("scan failed: %w", err) + } + + return rows.Err() +} + +// fieldInfo holds information about a struct field +type fieldInfo struct { + Index []int + Name string +} + +// buildFieldMap creates a mapping from column names to struct fields +func buildFieldMap(structType reflect.Type, _ []string) map[string]fieldInfo { + fieldMap := make(map[string]fieldInfo) + + // Iterate through struct fields + for i := 0; i < structType.NumField(); i++ { + field := structType.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get column name from struct tag or field name + colName := field.Name + + // Check for bun tag + if bunTag := field.Tag.Get("bun"); bunTag != "" { + parts := strings.Split(bunTag, ",") + if len(parts) > 0 && parts[0] != "" && parts[0] != "-" { + colName = parts[0] + } + } + + // Check for db tag (common convention) + if dbTag := field.Tag.Get("db"); dbTag != "" && dbTag != "-" { + colName = dbTag + } + + // Convert to lowercase for case-insensitive matching + colNameLower := strings.ToLower(colName) + + fieldMap[colNameLower] = fieldInfo{ + Index: field.Index, + Name: field.Name, + } + + // Also map by exact field name + fieldMap[strings.ToLower(field.Name)] = fieldInfo{ + Index: field.Index, + Name: field.Name, + } + } + + return fieldMap +} diff --git a/pkg/common/adapters/database/pgsql_example.go b/pkg/common/adapters/database/pgsql_example.go new file mode 100644 index 0000000..a90c59f --- /dev/null +++ b/pkg/common/adapters/database/pgsql_example.go @@ -0,0 +1,176 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + + _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// Example demonstrates how to use the PgSQL adapter +func ExamplePgSQLAdapter() error { + // Connect to PostgreSQL database + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return fmt.Errorf("failed to open database: %w", err) + } + defer db.Close() + + // Create the PgSQL adapter + adapter := NewPgSQLAdapter(db) + + // Enable query debugging (optional) + adapter.EnableQueryDebug() + + ctx := context.Background() + + // Example 1: Simple SELECT query + var results []map[string]interface{} + err = adapter.NewSelect(). + Table("users"). + Where("age > ?", 18). + Order("created_at DESC"). + Limit(10). + Scan(ctx, &results) + if err != nil { + return fmt.Errorf("select failed: %w", err) + } + + // Example 2: INSERT query + result, err := adapter.NewInsert(). + Table("users"). + Value("name", "John Doe"). + Value("email", "john@example.com"). + Value("age", 25). + Returning("id"). + Exec(ctx) + if err != nil { + return fmt.Errorf("insert failed: %w", err) + } + fmt.Printf("Rows affected: %d\n", result.RowsAffected()) + + // Example 3: UPDATE query + result, err = adapter.NewUpdate(). + Table("users"). + Set("name", "Jane Doe"). + Where("id = ?", 1). + Exec(ctx) + if err != nil { + return fmt.Errorf("update failed: %w", err) + } + fmt.Printf("Rows updated: %d\n", result.RowsAffected()) + + // Example 4: DELETE query + result, err = adapter.NewDelete(). + Table("users"). + Where("age < ?", 18). + Exec(ctx) + if err != nil { + return fmt.Errorf("delete failed: %w", err) + } + fmt.Printf("Rows deleted: %d\n", result.RowsAffected()) + + // Example 5: Using transactions + err = adapter.RunInTransaction(ctx, func(tx common.Database) error { + // Insert a new user + _, err := tx.NewInsert(). + Table("users"). + Value("name", "Transaction User"). + Value("email", "tx@example.com"). + Exec(ctx) + if err != nil { + return err + } + + // Update another user + _, err = tx.NewUpdate(). + Table("users"). + Set("verified", true). + Where("email = ?", "tx@example.com"). + Exec(ctx) + if err != nil { + return err + } + + // Both operations succeed or both rollback + return nil + }) + if err != nil { + return fmt.Errorf("transaction failed: %w", err) + } + + // Example 6: JOIN query + err = adapter.NewSelect(). + Table("users u"). + Column("u.id", "u.name", "p.title as post_title"). + LeftJoin("posts p ON p.user_id = u.id"). + Where("u.active = ?", true). + Scan(ctx, &results) + if err != nil { + return fmt.Errorf("join query failed: %w", err) + } + + // Example 7: Aggregation query + count, err := adapter.NewSelect(). + Table("users"). + Where("active = ?", true). + Count(ctx) + if err != nil { + return fmt.Errorf("count failed: %w", err) + } + fmt.Printf("Active users: %d\n", count) + + // Example 8: Raw SQL execution + _, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)") + if err != nil { + return fmt.Errorf("raw exec failed: %w", err) + } + + // Example 9: Raw SQL query + var users []map[string]interface{} + err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10) + if err != nil { + return fmt.Errorf("raw query failed: %w", err) + } + + return nil +} + +// User is an example model +type User struct { + ID int `json:"id"` + Name string `json:"name"` + Email string `json:"email"` + Age int `json:"age"` +} + +// TableName implements common.TableNameProvider +func (u User) TableName() string { + return "users" +} + +// ExampleWithModel demonstrates using models with the PgSQL adapter +func ExampleWithModel() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Use model with adapter + user := User{} + err = adapter.NewSelect(). + Model(&user). + Where("id = ?", 1). + Scan(ctx, &user) + + return err +} diff --git a/pkg/common/adapters/database/pgsql_integration_test.go b/pkg/common/adapters/database/pgsql_integration_test.go new file mode 100644 index 0000000..226ceff --- /dev/null +++ b/pkg/common/adapters/database/pgsql_integration_test.go @@ -0,0 +1,526 @@ +// +build integration + +package database + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/wait" +) + +// Integration test models +type IntegrationUser struct { + ID int `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` + CreatedAt time.Time `db:"created_at"` + Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"` +} + +func (u IntegrationUser) TableName() string { + return "users" +} + +type IntegrationPost struct { + ID int `db:"id"` + Title string `db:"title"` + Content string `db:"content"` + UserID int `db:"user_id"` + Published bool `db:"published"` + CreatedAt time.Time `db:"created_at"` + User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"` + Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"` +} + +func (p IntegrationPost) TableName() string { + return "posts" +} + +type IntegrationComment struct { + ID int `db:"id"` + Content string `db:"content"` + PostID int `db:"post_id"` + CreatedAt time.Time `db:"created_at"` + Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"` +} + +func (c IntegrationComment) TableName() string { + return "comments" +} + +// setupTestDB creates a PostgreSQL container and returns the connection +func setupTestDB(t *testing.T) (*sql.DB, func()) { + ctx := context.Background() + + req := testcontainers.ContainerRequest{ + Image: "postgres:15-alpine", + ExposedPorts: []string{"5432/tcp"}, + Env: map[string]string{ + "POSTGRES_USER": "testuser", + "POSTGRES_PASSWORD": "testpass", + "POSTGRES_DB": "testdb", + }, + WaitingFor: wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2). + WithStartupTimeout(60 * time.Second), + } + + postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{ + ContainerRequest: req, + Started: true, + }) + require.NoError(t, err) + + host, err := postgres.Host(ctx) + require.NoError(t, err) + + port, err := postgres.MappedPort(ctx, "5432") + require.NoError(t, err) + + dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable", + host, port.Port()) + + db, err := sql.Open("pgx", dsn) + require.NoError(t, err) + + // Wait for database to be ready + err = db.Ping() + require.NoError(t, err) + + // Create schema + createSchema(t, db) + + cleanup := func() { + db.Close() + postgres.Terminate(ctx) + } + + return db, cleanup +} + +// createSchema creates test tables +func createSchema(t *testing.T, db *sql.DB) { + schema := ` + DROP TABLE IF EXISTS comments CASCADE; + DROP TABLE IF EXISTS posts CASCADE; + DROP TABLE IF EXISTS users CASCADE; + + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + age INT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE posts ( + id SERIAL PRIMARY KEY, + title VARCHAR(255) NOT NULL, + content TEXT NOT NULL, + user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + published BOOLEAN DEFAULT false, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + + CREATE TABLE comments ( + id SERIAL PRIMARY KEY, + content TEXT NOT NULL, + post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + ` + + _, err := db.Exec(schema) + require.NoError(t, err) +} + +// TestIntegration_BasicCRUD tests basic CRUD operations +func TestIntegration_BasicCRUD(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // CREATE + result, err := adapter.NewInsert(). + Table("users"). + Value("name", "John Doe"). + Value("email", "john@example.com"). + Value("age", 25). + Exec(ctx) + + require.NoError(t, err) + assert.Equal(t, int64(1), result.RowsAffected()) + + // READ + var users []IntegrationUser + err = adapter.NewSelect(). + Table("users"). + Where("email = ?", "john@example.com"). + Scan(ctx, &users) + + require.NoError(t, err) + assert.Len(t, users, 1) + assert.Equal(t, "John Doe", users[0].Name) + assert.Equal(t, 25, users[0].Age) + + userID := users[0].ID + + // UPDATE + result, err = adapter.NewUpdate(). + Table("users"). + Set("age", 26). + Where("id = ?", userID). + Exec(ctx) + + require.NoError(t, err) + assert.Equal(t, int64(1), result.RowsAffected()) + + // Verify update + var updatedUser IntegrationUser + err = adapter.NewSelect(). + Table("users"). + Where("id = ?", userID). + Scan(ctx, &updatedUser) + + require.NoError(t, err) + assert.Equal(t, 26, updatedUser.Age) + + // DELETE + result, err = adapter.NewDelete(). + Table("users"). + Where("id = ?", userID). + Exec(ctx) + + require.NoError(t, err) + assert.Equal(t, int64(1), result.RowsAffected()) + + // Verify delete + count, err := adapter.NewSelect(). + Table("users"). + Where("id = ?", userID). + Count(ctx) + + require.NoError(t, err) + assert.Equal(t, 0, count) +} + +// TestIntegration_ScanModel tests ScanModel functionality +func TestIntegration_ScanModel(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Insert test data + _, err := adapter.NewInsert(). + Table("users"). + Value("name", "Jane Smith"). + Value("email", "jane@example.com"). + Value("age", 30). + Exec(ctx) + require.NoError(t, err) + + // Test single struct scan + user := &IntegrationUser{} + err = adapter.NewSelect(). + Model(user). + Table("users"). + Where("email = ?", "jane@example.com"). + ScanModel(ctx) + + require.NoError(t, err) + assert.Equal(t, "Jane Smith", user.Name) + assert.Equal(t, 30, user.Age) + + // Test slice scan + users := []*IntegrationUser{} + err = adapter.NewSelect(). + Model(&users). + Table("users"). + ScanModel(ctx) + + require.NoError(t, err) + assert.Len(t, users, 1) +} + +// TestIntegration_Transaction tests transaction handling +func TestIntegration_Transaction(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Successful transaction + err := adapter.RunInTransaction(ctx, func(tx common.Database) error { + _, err := tx.NewInsert(). + Table("users"). + Value("name", "Alice"). + Value("email", "alice@example.com"). + Value("age", 28). + Exec(ctx) + if err != nil { + return err + } + + _, err = tx.NewInsert(). + Table("users"). + Value("name", "Bob"). + Value("email", "bob@example.com"). + Value("age", 32). + Exec(ctx) + return err + }) + + require.NoError(t, err) + + // Verify both records exist + count, err := adapter.NewSelect(). + Table("users"). + Count(ctx) + require.NoError(t, err) + assert.Equal(t, 2, count) + + // Failed transaction (should rollback) + err = adapter.RunInTransaction(ctx, func(tx common.Database) error { + _, err := tx.NewInsert(). + Table("users"). + Value("name", "Charlie"). + Value("email", "charlie@example.com"). + Value("age", 35). + Exec(ctx) + if err != nil { + return err + } + + // Intentional error - duplicate email + _, err = tx.NewInsert(). + Table("users"). + Value("name", "David"). + Value("email", "alice@example.com"). // Duplicate + Value("age", 40). + Exec(ctx) + return err + }) + + assert.Error(t, err) + + // Verify rollback - count should still be 2 + count, err = adapter.NewSelect(). + Table("users"). + Count(ctx) + require.NoError(t, err) + assert.Equal(t, 2, count) +} + +// TestIntegration_Preload tests basic preload functionality +func TestIntegration_Preload(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Create test data + userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25) + createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true) + createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false) + + // Test Preload + var users []*IntegrationUser + err := adapter.NewSelect(). + Model(&IntegrationUser{}). + Table("users"). + Preload("Posts"). + Scan(ctx, &users) + + require.NoError(t, err) + assert.Len(t, users, 1) + assert.NotNil(t, users[0].Posts) + assert.Len(t, users[0].Posts, 2) +} + +// TestIntegration_PreloadRelation tests smart PreloadRelation +func TestIntegration_PreloadRelation(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Create test data + userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30) + postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true) + createTestComment(t, adapter, ctx, postID, "Great post!") + createTestComment(t, adapter, ctx, postID, "Thanks for sharing!") + + // Test PreloadRelation with belongs-to (should use JOIN) + var posts []*IntegrationPost + err := adapter.NewSelect(). + Model(&IntegrationPost{}). + Table("posts"). + PreloadRelation("User"). + Scan(ctx, &posts) + + require.NoError(t, err) + assert.Len(t, posts, 1) + // Note: JOIN preloading needs proper column selection to work + // For now, we test that it doesn't error + + // Test PreloadRelation with has-many (should use subquery) + posts = []*IntegrationPost{} + err = adapter.NewSelect(). + Model(&IntegrationPost{}). + Table("posts"). + PreloadRelation("Comments"). + Scan(ctx, &posts) + + require.NoError(t, err) + assert.Len(t, posts, 1) + if posts[0].Comments != nil { + assert.Len(t, posts[0].Comments, 2) + } +} + +// TestIntegration_JoinRelation tests explicit JoinRelation +func TestIntegration_JoinRelation(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Create test data + userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35) + createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true) + + // Test JoinRelation + var posts []*IntegrationPost + err := adapter.NewSelect(). + Model(&IntegrationPost{}). + Table("posts"). + JoinRelation("User"). + Scan(ctx, &posts) + + require.NoError(t, err) + assert.Len(t, posts, 1) +} + +// TestIntegration_ComplexQuery tests complex queries +func TestIntegration_ComplexQuery(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Create test data + userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25) + userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30) + userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35) + + createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true) + createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true) + createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false) + + // Complex query with joins, where, order, limit + var results []map[string]interface{} + err := adapter.NewSelect(). + Table("posts p"). + Column("p.title", "u.name as author_name", "u.age as author_age"). + LeftJoin("users u ON u.id = p.user_id"). + Where("p.published = ?", true). + WhereOr("u.age > ?", 25). + Order("u.age DESC"). + Limit(2). + Scan(ctx, &results) + + require.NoError(t, err) + assert.LessOrEqual(t, len(results), 2) +} + +// TestIntegration_Aggregation tests aggregation queries +func TestIntegration_Aggregation(t *testing.T) { + db, cleanup := setupTestDB(t) + defer cleanup() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Create test data + createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20) + createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25) + createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30) + + // Test Count + count, err := adapter.NewSelect(). + Table("users"). + Where("age >= ?", 25). + Count(ctx) + + require.NoError(t, err) + assert.Equal(t, 2, count) + + // Test Exists + exists, err := adapter.NewSelect(). + Table("users"). + Where("email = ?", "user1@example.com"). + Exists(ctx) + + require.NoError(t, err) + assert.True(t, exists) + + // Test Group By with aggregation + var results []map[string]interface{} + err = adapter.NewSelect(). + Table("users"). + Column("age", "COUNT(*) as count"). + Group("age"). + Having("COUNT(*) > ?", 0). + Order("age ASC"). + Scan(ctx, &results) + + require.NoError(t, err) + assert.Len(t, results, 3) +} + +// Helper functions + +func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int { + var userID int + err := adapter.Query(ctx, &userID, + "INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id", + name, email, age) + require.NoError(t, err) + return userID +} + +func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int { + var postID int + err := adapter.Query(ctx, &postID, + "INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id", + title, content, userID, published) + require.NoError(t, err) + return postID +} + +func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int { + var commentID int + err := adapter.Query(ctx, &commentID, + "INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id", + content, postID) + require.NoError(t, err) + return commentID +} diff --git a/pkg/common/adapters/database/pgsql_preload_example.go b/pkg/common/adapters/database/pgsql_preload_example.go new file mode 100644 index 0000000..b3035d0 --- /dev/null +++ b/pkg/common/adapters/database/pgsql_preload_example.go @@ -0,0 +1,275 @@ +package database + +import ( + "context" + "database/sql" + + _ "github.com/jackc/pgx/v5/stdlib" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// Example models for demonstrating preload functionality + +// Author model - has many Posts +type Author struct { + ID int `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Posts []*Post `bun:"rel:has-many,join:id=author_id"` +} + +func (a Author) TableName() string { + return "authors" +} + +// Post model - belongs to Author, has many Comments +type Post struct { + ID int `db:"id"` + Title string `db:"title"` + Content string `db:"content"` + AuthorID int `db:"author_id"` + Author *Author `bun:"rel:belongs-to,join:author_id=id"` + Comments []*Comment `bun:"rel:has-many,join:id=post_id"` +} + +func (p Post) TableName() string { + return "posts" +} + +// Comment model - belongs to Post +type Comment struct { + ID int `db:"id"` + Content string `db:"content"` + PostID int `db:"post_id"` + Post *Post `bun:"rel:belongs-to,join:post_id=id"` +} + +func (c Comment) TableName() string { + return "comments" +} + +// ExamplePreload demonstrates the Preload functionality +func ExamplePreload() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Example 1: Simple Preload (uses subquery for has-many) + var authors []*Author + err = adapter.NewSelect(). + Model(&Author{}). + Table("authors"). + Preload("Posts"). // Load all posts for each author + Scan(ctx, &authors) + if err != nil { + return err + } + + // Now authors[i].Posts will be populated with their posts + + return nil +} + +// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection +func ExamplePreloadRelation() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Example 1: PreloadRelation auto-detects has-many (uses subquery) + var authors []*Author + err = adapter.NewSelect(). + Model(&Author{}). + Table("authors"). + PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery { + return q.Where("published = ?", true).Order("created_at DESC") + }). + Where("active = ?", true). + Scan(ctx, &authors) + if err != nil { + return err + } + + // Example 2: PreloadRelation auto-detects belongs-to (uses JOIN) + var posts []*Post + err = adapter.NewSelect(). + Model(&Post{}). + Table("posts"). + PreloadRelation("Author"). // Will use JOIN because it's belongs-to + Scan(ctx, &posts) + if err != nil { + return err + } + + // Example 3: Nested preloads + err = adapter.NewSelect(). + Model(&Author{}). + Table("authors"). + PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery { + // First load posts, then preload comments for each post + return q.Limit(10) + }). + Scan(ctx, &authors) + if err != nil { + return err + } + + // Manually load nested relationships (two-level preloading) + for _, author := range authors { + if author.Posts != nil { + for _, post := range author.Posts { + var comments []*Comment + err := adapter.NewSelect(). + Table("comments"). + Where("post_id = ?", post.ID). + Scan(ctx, &comments) + if err == nil { + post.Comments = comments + } + } + } + } + + return nil +} + +// ExampleJoinRelation demonstrates explicit JOIN loading +func ExampleJoinRelation() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Example 1: Force JOIN for belongs-to relationship + var posts []*Post + err = adapter.NewSelect(). + Model(&Post{}). + Table("posts"). + JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery { + return q.Where("active = ?", true) + }). + Scan(ctx, &posts) + if err != nil { + return err + } + + // Example 2: Multiple JOINs + err = adapter.NewSelect(). + Model(&Post{}). + Table("posts p"). + Column("p.*", "a.name as author_name", "a.email as author_email"). + LeftJoin("authors a ON a.id = p.author_id"). + Where("p.published = ?", true). + Scan(ctx, &posts) + + return err +} + +// ExampleScanModel demonstrates ScanModel with struct destinations +func ExampleScanModel() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + // Example 1: Scan single struct + author := Author{} + err = adapter.NewSelect(). + Model(&author). + Table("authors"). + Where("id = ?", 1). + ScanModel(ctx) // ScanModel automatically uses the model set with Model() + + if err != nil { + return err + } + + // Example 2: Scan slice of structs + authors := []*Author{} + err = adapter.NewSelect(). + Model(&authors). + Table("authors"). + Where("active = ?", true). + Limit(10). + ScanModel(ctx) + + return err +} + +// ExampleCompleteWorkflow demonstrates a complete workflow with preloading +func ExampleCompleteWorkflow() error { + dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable" + db, err := sql.Open("pgx", dsn) + if err != nil { + return err + } + defer db.Close() + + adapter := NewPgSQLAdapter(db) + adapter.EnableQueryDebug() // Enable query logging + ctx := context.Background() + + // Step 1: Create an author + author := &Author{ + Name: "John Doe", + Email: "john@example.com", + } + + result, err := adapter.NewInsert(). + Table("authors"). + Value("name", author.Name). + Value("email", author.Email). + Returning("id"). + Exec(ctx) + if err != nil { + return err + } + + _ = result + + // Step 2: Load author with all their posts + var loadedAuthor Author + err = adapter.NewSelect(). + Model(&loadedAuthor). + Table("authors"). + PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery { + return q.Order("created_at DESC").Limit(5) + }). + Where("id = ?", 1). + ScanModel(ctx) + if err != nil { + return err + } + + // Step 3: Update author name + _, err = adapter.NewUpdate(). + Table("authors"). + Set("name", "Jane Doe"). + Where("id = ?", 1). + Exec(ctx) + + return err +} diff --git a/pkg/common/adapters/database/pgsql_test.go b/pkg/common/adapters/database/pgsql_test.go new file mode 100644 index 0000000..e5b5192 --- /dev/null +++ b/pkg/common/adapters/database/pgsql_test.go @@ -0,0 +1,629 @@ +package database + +import ( + "context" + "database/sql" + "reflect" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/bitechdev/ResolveSpec/pkg/common" +) + +// Test models +type TestUser struct { + ID int `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` +} + +func (u TestUser) TableName() string { + return "users" +} + +type TestPost struct { + ID int `db:"id"` + Title string `db:"title"` + Content string `db:"content"` + UserID int `db:"user_id"` + User *TestUser `bun:"rel:belongs-to,join:user_id=id"` + Comments []TestComment `bun:"rel:has-many,join:id=post_id"` +} + +func (p TestPost) TableName() string { + return "posts" +} + +type TestComment struct { + ID int `db:"id"` + Content string `db:"content"` + PostID int `db:"post_id"` +} + +func (c TestComment) TableName() string { + return "comments" +} + +// TestNewPgSQLAdapter tests adapter creation +func TestNewPgSQLAdapter(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + adapter := NewPgSQLAdapter(db) + assert.NotNil(t, adapter) + assert.Equal(t, db, adapter.db) +} + +// TestPgSQLSelectQuery_BuildSQL tests SQL query building +func TestPgSQLSelectQuery_BuildSQL(t *testing.T) { + tests := []struct { + name string + setup func(*PgSQLSelectQuery) + expected string + }{ + { + name: "simple select", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + }, + expected: "SELECT * FROM users", + }, + { + name: "select with columns", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + q.columns = []string{"id", "name", "email"} + }, + expected: "SELECT id, name, email FROM users", + }, + { + name: "select with where", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + q.whereClauses = []string{"age > $1"} + q.args = []interface{}{18} + }, + expected: "SELECT * FROM users WHERE (age > $1)", + }, + { + name: "select with order and limit", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + q.orderBy = []string{"created_at DESC"} + q.limit = 10 + q.offset = 5 + }, + expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5", + }, + { + name: "select with join", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"} + }, + expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id", + }, + { + name: "select with group and having", + setup: func(q *PgSQLSelectQuery) { + q.tableName = "users" + q.groupBy = []string{"country"} + q.havingClauses = []string{"COUNT(*) > $1"} + q.args = []interface{}{5} + }, + expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := &PgSQLSelectQuery{ + columns: []string{"*"}, + } + tt.setup(q) + sql := q.buildSQL() + assert.Equal(t, tt.expected, sql) + }) + } +} + +// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement +func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) { + tests := []struct { + name string + query string + argCount int + paramCounter int + expected string + }{ + { + name: "single placeholder", + query: "age > ?", + argCount: 1, + paramCounter: 0, + expected: "age > $1", + }, + { + name: "multiple placeholders", + query: "age > ? AND status = ?", + argCount: 2, + paramCounter: 0, + expected: "age > $1 AND status = $2", + }, + { + name: "with existing counter", + query: "name = ?", + argCount: 1, + paramCounter: 5, + expected: "name = $6", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := &PgSQLSelectQuery{paramCounter: tt.paramCounter} + result := q.replacePlaceholders(tt.query, tt.argCount) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestPgSQLSelectQuery_Chaining tests method chaining +func TestPgSQLSelectQuery_Chaining(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + adapter := NewPgSQLAdapter(db) + query := adapter.NewSelect(). + Table("users"). + Column("id", "name"). + Where("age > ?", 18). + Order("name ASC"). + Limit(10). + Offset(5) + + pgQuery := query.(*PgSQLSelectQuery) + assert.Equal(t, "users", pgQuery.tableName) + assert.Equal(t, []string{"id", "name"}, pgQuery.columns) + assert.Len(t, pgQuery.whereClauses, 1) + assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy) + assert.Equal(t, 10, pgQuery.limit) + assert.Equal(t, 5, pgQuery.offset) +} + +// TestPgSQLSelectQuery_Model tests model setting +func TestPgSQLSelectQuery_Model(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + adapter := NewPgSQLAdapter(db) + user := &TestUser{} + query := adapter.NewSelect().Model(user) + + pgQuery := query.(*PgSQLSelectQuery) + assert.Equal(t, "users", pgQuery.tableName) + assert.Equal(t, user, pgQuery.model) +} + +// TestScanRowsToStructSlice tests scanning rows into struct slice +func TestScanRowsToStructSlice(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}). + AddRow(1, "John Doe", "john@example.com", 25). + AddRow(2, "Jane Smith", "jane@example.com", 30) + + mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + var users []TestUser + err = adapter.NewSelect(). + Table("users"). + Scan(ctx, &users) + + require.NoError(t, err) + assert.Len(t, users, 2) + assert.Equal(t, "John Doe", users[0].Name) + assert.Equal(t, "jane@example.com", users[1].Email) + assert.Equal(t, 30, users[1].Age) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice +func TestScanRowsToStructSlicePointers(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}). + AddRow(1, "John Doe", "john@example.com", 25) + + mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + var users []*TestUser + err = adapter.NewSelect(). + Table("users"). + Scan(ctx, &users) + + require.NoError(t, err) + assert.Len(t, users, 1) + assert.NotNil(t, users[0]) + assert.Equal(t, "John Doe", users[0].Name) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestScanRowsToSingleStruct tests scanning a single row +func TestScanRowsToSingleStruct(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}). + AddRow(1, "John Doe", "john@example.com", 25) + + mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + var user TestUser + err = adapter.NewSelect(). + Table("users"). + Where("id = ?", 1). + Scan(ctx, &user) + + require.NoError(t, err) + assert.Equal(t, 1, user.ID) + assert.Equal(t, "John Doe", user.Name) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestScanRowsToMapSlice tests scanning into map slice +func TestScanRowsToMapSlice(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"id", "name", "email"}). + AddRow(1, "John Doe", "john@example.com"). + AddRow(2, "Jane Smith", "jane@example.com") + + mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + var results []map[string]interface{} + err = adapter.NewSelect(). + Table("users"). + Scan(ctx, &results) + + require.NoError(t, err) + assert.Len(t, results, 2) + assert.Equal(t, int64(1), results[0]["id"]) + assert.Equal(t, "John Doe", results[0]["name"]) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLInsertQuery_Exec tests insert query execution +func TestPgSQLInsertQuery_Exec(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectExec("INSERT INTO users"). + WithArgs("John Doe", "john@example.com", 25). + WillReturnResult(sqlmock.NewResult(1, 1)) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + result, err := adapter.NewInsert(). + Table("users"). + Value("name", "John Doe"). + Value("email", "john@example.com"). + Value("age", 25). + Exec(ctx) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int64(1), result.RowsAffected()) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLUpdateQuery_Exec tests update query execution +func TestPgSQLUpdateQuery_Exec(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + // Note: Args order is SET values first, then WHERE values + mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2"). + WithArgs("Jane Doe", 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + result, err := adapter.NewUpdate(). + Table("users"). + Set("name", "Jane Doe"). + Where("id = ?", 1). + Exec(ctx) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int64(1), result.RowsAffected()) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLDeleteQuery_Exec tests delete query execution +func TestPgSQLDeleteQuery_Exec(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectExec("DELETE FROM users WHERE id = \\$1"). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + result, err := adapter.NewDelete(). + Table("users"). + Where("id = ?", 1). + Exec(ctx) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int64(1), result.RowsAffected()) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLSelectQuery_Count tests count query +func TestPgSQLSelectQuery_Count(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"count"}).AddRow(42) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + count, err := adapter.NewSelect(). + Table("users"). + Count(ctx) + + require.NoError(t, err) + assert.Equal(t, 42, count) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLSelectQuery_Exists tests exists query +func TestPgSQLSelectQuery_Exists(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"count"}).AddRow(1) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + exists, err := adapter.NewSelect(). + Table("users"). + Where("email = ?", "john@example.com"). + Exists(ctx) + + require.NoError(t, err) + assert.True(t, exists) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLAdapter_Transaction tests transaction handling +func TestPgSQLAdapter_Transaction(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + err = adapter.RunInTransaction(ctx, func(tx common.Database) error { + _, err := tx.NewInsert(). + Table("users"). + Value("name", "John"). + Exec(ctx) + return err + }) + + require.NoError(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestPgSQLAdapter_TransactionRollback tests transaction rollback +func TestPgSQLAdapter_TransactionRollback(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + mock.ExpectBegin() + mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone) + mock.ExpectRollback() + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + err = adapter.RunInTransaction(ctx, func(tx common.Database) error { + _, err := tx.NewInsert(). + Table("users"). + Value("name", "John"). + Exec(ctx) + return err + }) + + assert.Error(t, err) + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestBuildFieldMap tests field mapping construction +func TestBuildFieldMap(t *testing.T) { + userType := reflect.TypeOf(TestUser{}) + fieldMap := buildFieldMap(userType, nil) + + assert.NotEmpty(t, fieldMap) + + // Check that fields are mapped + assert.Contains(t, fieldMap, "id") + assert.Contains(t, fieldMap, "name") + assert.Contains(t, fieldMap, "email") + assert.Contains(t, fieldMap, "age") + + // Check field info + idInfo := fieldMap["id"] + assert.Equal(t, "ID", idInfo.Name) +} + +// TestGetRelationMetadata tests relationship metadata extraction +func TestGetRelationMetadata(t *testing.T) { + q := &PgSQLSelectQuery{ + model: &TestPost{}, + } + + // Test belongs-to relationship + meta := q.getRelationMetadata("User") + assert.NotNil(t, meta) + assert.Equal(t, "User", meta.fieldName) + + // Test has-many relationship + meta = q.getRelationMetadata("Comments") + assert.NotNil(t, meta) + assert.Equal(t, "Comments", meta.fieldName) +} + +// TestPreloadConfiguration tests preload configuration +func TestPreloadConfiguration(t *testing.T) { + db, _, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + adapter := NewPgSQLAdapter(db) + + // Test Preload + query := adapter.NewSelect(). + Model(&TestPost{}). + Table("posts"). + Preload("User") + + pgQuery := query.(*PgSQLSelectQuery) + assert.Len(t, pgQuery.preloads, 1) + assert.Equal(t, "User", pgQuery.preloads[0].relation) + assert.False(t, pgQuery.preloads[0].useJoin) + + // Test PreloadRelation + query = adapter.NewSelect(). + Model(&TestPost{}). + Table("posts"). + PreloadRelation("Comments") + + pgQuery = query.(*PgSQLSelectQuery) + assert.Len(t, pgQuery.preloads, 1) + assert.Equal(t, "Comments", pgQuery.preloads[0].relation) + + // Test JoinRelation + query = adapter.NewSelect(). + Model(&TestPost{}). + Table("posts"). + JoinRelation("User") + + pgQuery = query.(*PgSQLSelectQuery) + assert.Len(t, pgQuery.preloads, 1) + assert.Equal(t, "User", pgQuery.preloads[0].relation) + assert.True(t, pgQuery.preloads[0].useJoin) +} + +// TestScanModel tests ScanModel functionality +func TestScanModel(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}). + AddRow(1, "John Doe", "john@example.com", 25) + + mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + user := &TestUser{} + err = adapter.NewSelect(). + Model(user). + Table("users"). + Where("id = ?", 1). + ScanModel(ctx) + + require.NoError(t, err) + assert.Equal(t, 1, user.ID) + assert.Equal(t, "John Doe", user.Name) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +// TestRawSQL tests raw SQL execution +func TestRawSQL(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer db.Close() + + // Test Exec + mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0)) + + adapter := NewPgSQLAdapter(db) + ctx := context.Background() + + _, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)") + require.NoError(t, err) + + // Test Query + rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test") + mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows) + + var results []map[string]interface{} + err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1) + require.NoError(t, err) + assert.Len(t, results, 1) + + assert.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/pkg/common/adapters/database/test_helpers.go b/pkg/common/adapters/database/test_helpers.go new file mode 100644 index 0000000..8a5470e --- /dev/null +++ b/pkg/common/adapters/database/test_helpers.go @@ -0,0 +1,132 @@ +package database + +import ( + "context" + "database/sql" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestHelper provides utilities for database testing +type TestHelper struct { + DB *sql.DB + Adapter *PgSQLAdapter + t *testing.T +} + +// NewTestHelper creates a new test helper +func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper { + return &TestHelper{ + DB: db, + Adapter: NewPgSQLAdapter(db), + t: t, + } +} + +// CleanupTables truncates all test tables +func (h *TestHelper) CleanupTables() { + ctx := context.Background() + tables := []string{"comments", "posts", "users"} + + for _, table := range tables { + _, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE") + require.NoError(h.t, err) + } +} + +// InsertUser inserts a test user and returns the ID +func (h *TestHelper) InsertUser(name, email string, age int) int { + ctx := context.Background() + result, err := h.Adapter.NewInsert(). + Table("users"). + Value("name", name). + Value("email", email). + Value("age", age). + Exec(ctx) + + require.NoError(h.t, err) + id, _ := result.LastInsertId() + return int(id) +} + +// InsertPost inserts a test post and returns the ID +func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int { + ctx := context.Background() + result, err := h.Adapter.NewInsert(). + Table("posts"). + Value("user_id", userID). + Value("title", title). + Value("content", content). + Value("published", published). + Exec(ctx) + + require.NoError(h.t, err) + id, _ := result.LastInsertId() + return int(id) +} + +// InsertComment inserts a test comment and returns the ID +func (h *TestHelper) InsertComment(postID int, content string) int { + ctx := context.Background() + result, err := h.Adapter.NewInsert(). + Table("comments"). + Value("post_id", postID). + Value("content", content). + Exec(ctx) + + require.NoError(h.t, err) + id, _ := result.LastInsertId() + return int(id) +} + +// AssertUserExists checks if a user exists by email +func (h *TestHelper) AssertUserExists(email string) { + ctx := context.Background() + exists, err := h.Adapter.NewSelect(). + Table("users"). + Where("email = ?", email). + Exists(ctx) + + require.NoError(h.t, err) + require.True(h.t, exists, "User with email %s should exist", email) +} + +// AssertUserCount asserts the number of users +func (h *TestHelper) AssertUserCount(expected int) { + ctx := context.Background() + count, err := h.Adapter.NewSelect(). + Table("users"). + Count(ctx) + + require.NoError(h.t, err) + require.Equal(h.t, expected, count) +} + +// GetUserByEmail retrieves a user by email +func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} { + ctx := context.Background() + var results []map[string]interface{} + err := h.Adapter.NewSelect(). + Table("users"). + Where("email = ?", email). + Scan(ctx, &results) + + require.NoError(h.t, err) + require.Len(h.t, results, 1, "Expected exactly one user with email %s", email) + return results[0] +} + +// BeginTestTransaction starts a transaction for testing +func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) { + ctx := context.Background() + tx, err := h.DB.BeginTx(ctx, nil) + require.NoError(h.t, err) + + adapter := &PgSQLTxAdapter{tx: tx} + cleanup := func() { + tx.Rollback() + } + + return adapter, cleanup +}