Compare commits

..

6 Commits

Author SHA1 Message Date
Hein
8a9423df6d Fixed DatabaseAuthenticator JSON value. Added make tag 2025-12-11 13:59:41 +02:00
Hein
4cc943b9d3 Added row PgSQLAdapter
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
2025-12-10 15:28:09 +02:00
Hein
68dee78a34 Fixed filterExtendedOptions 2025-12-10 12:25:23 +02:00
Hein
efb9e5d9d5 Removed the buggy filter expand columns 2025-12-10 12:15:18 +02:00
Hein
490ae37c6d Fixed bugs in extractTableAndColumn 2025-12-10 11:48:03 +02:00
Hein
99307e31e6 More debugging on bun for scan issues 2025-12-10 11:16:25 +02:00
13 changed files with 3333 additions and 21 deletions

82
.github/workflows/make_tag.yml vendored Normal file
View File

@@ -0,0 +1,82 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Create Go Release (Tag Versioning)
on:
workflow_dispatch:
inputs:
semver:
description: "New Version"
required: true
default: "patch"
type: choice
options:
- patch
- minor
- major
jobs:
tag_and_commit:
name: "Tag and Commit ${{ github.event.inputs.semver }}"
runs-on: linux
permissions:
contents: write # 'write' access to repository contents
pull-requests: write # 'write' access to pull requests
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Git
run: |
git config --global user.name "Hein"
git config --global user.email "hein.puth@gmail.com"
- name: Fetch latest tag
id: latest_tag
run: |
git fetch --tags
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
echo "::set-output name=tag::$latest_tag"
- name: Determine new tag version
id: new_tag
run: |
current_tag=${{ steps.latest_tag.outputs.tag }}
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
IFS='.' read -r -a version_parts <<< "$version"
major=${version_parts[0]}
minor=${version_parts[1]}
patch=${version_parts[2]}
case "${{ github.event.inputs.semver }}" in
"patch")
((patch++))
;;
"minor")
((minor++))
patch=0
;;
"release")
((major++))
minor=0
patch=0
;;
*)
echo "Invalid semver input"
exit 1
;;
esac
new_tag="v$major.$minor.$patch"
echo "::set-output name=tag::$new_tag"
- name: Create tag
run: |
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
- name: Push changes
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
force: true
tags: true

14
.vscode/tasks.json vendored
View File

@@ -230,7 +230,17 @@
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test"
"group": "build"
},
{
"type": "shell",
"label": "go: lint workspace (fix)",
"command": "golangci-lint run --timeout=5m --fix",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "build"
},
{
"type": "shell",
@@ -275,4 +285,4 @@
"command": "sh ${workspaceFolder}/make_release.sh"
}
]
}
}

View File

@@ -48,21 +48,42 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
}
// Log the type being scanned into
logger.Debug("Debug scan into type: %s (kind: %s)", v.Type().Name(), v.Kind())
typeName := v.Type().String()
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
// If it's a struct, log all field types
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
field := v.Type().Field(i)
fieldValue := v.Field(i)
// Handle slice types - inspect the element type
var structType reflect.Type
if v.Kind() == reflect.Slice {
elemType := v.Type().Elem()
logger.Debug(" Slice element type: %s", elemType)
// If slice of pointers, get the underlying type
if elemType.Kind() == reflect.Ptr {
structType = elemType.Elem()
} else {
structType = elemType
}
} else if v.Kind() == reflect.Struct {
structType = v.Type()
}
// If we have a struct type, log all its fields
if structType != nil && structType.Kind() == reflect.Struct {
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
// Log embedded fields specially
if field.Anonymous {
logger.Debug(" Embedded field [%d]: %s (type: %s, kind: %s)",
i, field.Name, field.Type, fieldValue.Kind())
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
} else {
logger.Debug(" Field [%d]: %s (type: %s, kind: %s, tag: %s)",
i, field.Name, field.Type, fieldValue.Kind(), field.Tag.Get("bun"))
bunTag := field.Tag.Get("bun")
if bunTag == "" {
bunTag = "(no tag)"
}
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), bunTag)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,176 @@
package database
import (
"context"
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example demonstrates how to use the PgSQL adapter
func ExamplePgSQLAdapter() error {
// Connect to PostgreSQL database
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer db.Close()
// Create the PgSQL adapter
adapter := NewPgSQLAdapter(db)
// Enable query debugging (optional)
adapter.EnableQueryDebug()
ctx := context.Background()
// Example 1: Simple SELECT query
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Where("age > ?", 18).
Order("created_at DESC").
Limit(10).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("select failed: %w", err)
}
// Example 2: INSERT query
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Returning("id").
Exec(ctx)
if err != nil {
return fmt.Errorf("insert failed: %w", err)
}
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
// Example 3: UPDATE query
result, err = adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
if err != nil {
return fmt.Errorf("update failed: %w", err)
}
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
// Example 4: DELETE query
result, err = adapter.NewDelete().
Table("users").
Where("age < ?", 18).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete failed: %w", err)
}
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
// Example 5: Using transactions
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
// Insert a new user
_, err := tx.NewInsert().
Table("users").
Value("name", "Transaction User").
Value("email", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Update another user
_, err = tx.NewUpdate().
Table("users").
Set("verified", true).
Where("email = ?", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Both operations succeed or both rollback
return nil
})
if err != nil {
return fmt.Errorf("transaction failed: %w", err)
}
// Example 6: JOIN query
err = adapter.NewSelect().
Table("users u").
Column("u.id", "u.name", "p.title as post_title").
LeftJoin("posts p ON p.user_id = u.id").
Where("u.active = ?", true).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("join query failed: %w", err)
}
// Example 7: Aggregation query
count, err := adapter.NewSelect().
Table("users").
Where("active = ?", true).
Count(ctx)
if err != nil {
return fmt.Errorf("count failed: %w", err)
}
fmt.Printf("Active users: %d\n", count)
// Example 8: Raw SQL execution
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
if err != nil {
return fmt.Errorf("raw exec failed: %w", err)
}
// Example 9: Raw SQL query
var users []map[string]interface{}
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
if err != nil {
return fmt.Errorf("raw query failed: %w", err)
}
return nil
}
// User is an example model
type User struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Age int `json:"age"`
}
// TableName implements common.TableNameProvider
func (u User) TableName() string {
return "users"
}
// ExampleWithModel demonstrates using models with the PgSQL adapter
func ExampleWithModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Use model with adapter
user := User{}
err = adapter.NewSelect().
Model(&user).
Where("id = ?", 1).
Scan(ctx, &user)
return err
}

View File

@@ -0,0 +1,526 @@
// +build integration
package database
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// Integration test models
type IntegrationUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
}
func (u IntegrationUser) TableName() string {
return "users"
}
type IntegrationPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
Published bool `db:"published"`
CreatedAt time.Time `db:"created_at"`
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
}
func (p IntegrationPost) TableName() string {
return "posts"
}
type IntegrationComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
CreatedAt time.Time `db:"created_at"`
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
}
func (c IntegrationComment) TableName() string {
return "comments"
}
// setupTestDB creates a PostgreSQL container and returns the connection
func setupTestDB(t *testing.T) (*sql.DB, func()) {
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: "postgres:15-alpine",
ExposedPorts: []string{"5432/tcp"},
Env: map[string]string{
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpass",
"POSTGRES_DB": "testdb",
},
WaitingFor: wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(60 * time.Second),
}
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
require.NoError(t, err)
host, err := postgres.Host(ctx)
require.NoError(t, err)
port, err := postgres.MappedPort(ctx, "5432")
require.NoError(t, err)
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
host, port.Port())
db, err := sql.Open("pgx", dsn)
require.NoError(t, err)
// Wait for database to be ready
err = db.Ping()
require.NoError(t, err)
// Create schema
createSchema(t, db)
cleanup := func() {
db.Close()
postgres.Terminate(ctx)
}
return db, cleanup
}
// createSchema creates test tables
func createSchema(t *testing.T, db *sql.DB) {
schema := `
DROP TABLE IF EXISTS comments CASCADE;
DROP TABLE IF EXISTS posts CASCADE;
DROP TABLE IF EXISTS users CASCADE;
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
age INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE posts (
id SERIAL PRIMARY KEY,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
published BOOLEAN DEFAULT false,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE comments (
id SERIAL PRIMARY KEY,
content TEXT NOT NULL,
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(schema)
require.NoError(t, err)
}
// TestIntegration_BasicCRUD tests basic CRUD operations
func TestIntegration_BasicCRUD(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// CREATE
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// READ
var users []IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, 25, users[0].Age)
userID := users[0].ID
// UPDATE
result, err = adapter.NewUpdate().
Table("users").
Set("age", 26).
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify update
var updatedUser IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Scan(ctx, &updatedUser)
require.NoError(t, err)
assert.Equal(t, 26, updatedUser.Age)
// DELETE
result, err = adapter.NewDelete().
Table("users").
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify delete
count, err := adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 0, count)
}
// TestIntegration_ScanModel tests ScanModel functionality
func TestIntegration_ScanModel(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Insert test data
_, err := adapter.NewInsert().
Table("users").
Value("name", "Jane Smith").
Value("email", "jane@example.com").
Value("age", 30).
Exec(ctx)
require.NoError(t, err)
// Test single struct scan
user := &IntegrationUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("email = ?", "jane@example.com").
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, "Jane Smith", user.Name)
assert.Equal(t, 30, user.Age)
// Test slice scan
users := []*IntegrationUser{}
err = adapter.NewSelect().
Model(&users).
Table("users").
ScanModel(ctx)
require.NoError(t, err)
assert.Len(t, users, 1)
}
// TestIntegration_Transaction tests transaction handling
func TestIntegration_Transaction(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Successful transaction
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Alice").
Value("email", "alice@example.com").
Value("age", 28).
Exec(ctx)
if err != nil {
return err
}
_, err = tx.NewInsert().
Table("users").
Value("name", "Bob").
Value("email", "bob@example.com").
Value("age", 32).
Exec(ctx)
return err
})
require.NoError(t, err)
// Verify both records exist
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Failed transaction (should rollback)
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Charlie").
Value("email", "charlie@example.com").
Value("age", 35).
Exec(ctx)
if err != nil {
return err
}
// Intentional error - duplicate email
_, err = tx.NewInsert().
Table("users").
Value("name", "David").
Value("email", "alice@example.com"). // Duplicate
Value("age", 40).
Exec(ctx)
return err
})
assert.Error(t, err)
// Verify rollback - count should still be 2
count, err = adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
}
// TestIntegration_Preload tests basic preload functionality
func TestIntegration_Preload(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
// Test Preload
var users []*IntegrationUser
err := adapter.NewSelect().
Model(&IntegrationUser{}).
Table("users").
Preload("Posts").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0].Posts)
assert.Len(t, users[0].Posts, 2)
}
// TestIntegration_PreloadRelation tests smart PreloadRelation
func TestIntegration_PreloadRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
createTestComment(t, adapter, ctx, postID, "Great post!")
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
// Test PreloadRelation with belongs-to (should use JOIN)
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
// Note: JOIN preloading needs proper column selection to work
// For now, we test that it doesn't error
// Test PreloadRelation with has-many (should use subquery)
posts = []*IntegrationPost{}
err = adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("Comments").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
if posts[0].Comments != nil {
assert.Len(t, posts[0].Comments, 2)
}
}
// TestIntegration_JoinRelation tests explicit JoinRelation
func TestIntegration_JoinRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
// Test JoinRelation
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
JoinRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
}
// TestIntegration_ComplexQuery tests complex queries
func TestIntegration_ComplexQuery(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
// Complex query with joins, where, order, limit
var results []map[string]interface{}
err := adapter.NewSelect().
Table("posts p").
Column("p.title", "u.name as author_name", "u.age as author_age").
LeftJoin("users u ON u.id = p.user_id").
Where("p.published = ?", true).
WhereOr("u.age > ?", 25).
Order("u.age DESC").
Limit(2).
Scan(ctx, &results)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 2)
}
// TestIntegration_Aggregation tests aggregation queries
func TestIntegration_Aggregation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
// Test Count
count, err := adapter.NewSelect().
Table("users").
Where("age >= ?", 25).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Test Exists
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "user1@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
// Test Group By with aggregation
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Column("age", "COUNT(*) as count").
Group("age").
Having("COUNT(*) > ?", 0).
Order("age ASC").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 3)
}
// Helper functions
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
var userID int
err := adapter.Query(ctx, &userID,
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
name, email, age)
require.NoError(t, err)
return userID
}
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
var postID int
err := adapter.Query(ctx, &postID,
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
title, content, userID, published)
require.NoError(t, err)
return postID
}
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
var commentID int
err := adapter.Query(ctx, &commentID,
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
content, postID)
require.NoError(t, err)
return commentID
}

View File

@@ -0,0 +1,275 @@
package database
import (
"context"
"database/sql"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example models for demonstrating preload functionality
// Author model - has many Posts
type Author struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Posts []*Post `bun:"rel:has-many,join:id=author_id"`
}
func (a Author) TableName() string {
return "authors"
}
// Post model - belongs to Author, has many Comments
type Post struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
AuthorID int `db:"author_id"`
Author *Author `bun:"rel:belongs-to,join:author_id=id"`
Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
}
func (p Post) TableName() string {
return "posts"
}
// Comment model - belongs to Post
type Comment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
Post *Post `bun:"rel:belongs-to,join:post_id=id"`
}
func (c Comment) TableName() string {
return "comments"
}
// ExamplePreload demonstrates the Preload functionality
func ExamplePreload() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Simple Preload (uses subquery for has-many)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
Preload("Posts"). // Load all posts for each author
Scan(ctx, &authors)
if err != nil {
return err
}
// Now authors[i].Posts will be populated with their posts
return nil
}
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection
func ExamplePreloadRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: PreloadRelation auto-detects has-many (uses subquery)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Where("published = ?", true).Order("created_at DESC")
}).
Where("active = ?", true).
Scan(ctx, &authors)
if err != nil {
return err
}
// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
PreloadRelation("Author"). // Will use JOIN because it's belongs-to
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 3: Nested preloads
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
// First load posts, then preload comments for each post
return q.Limit(10)
}).
Scan(ctx, &authors)
if err != nil {
return err
}
// Manually load nested relationships (two-level preloading)
for _, author := range authors {
if author.Posts != nil {
for _, post := range author.Posts {
var comments []*Comment
err := adapter.NewSelect().
Table("comments").
Where("post_id = ?", post.ID).
Scan(ctx, &comments)
if err == nil {
post.Comments = comments
}
}
}
}
return nil
}
// ExampleJoinRelation demonstrates explicit JOIN loading
func ExampleJoinRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Force JOIN for belongs-to relationship
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 2: Multiple JOINs
err = adapter.NewSelect().
Model(&Post{}).
Table("posts p").
Column("p.*", "a.name as author_name", "a.email as author_email").
LeftJoin("authors a ON a.id = p.author_id").
Where("p.published = ?", true).
Scan(ctx, &posts)
return err
}
// ExampleScanModel demonstrates ScanModel with struct destinations
func ExampleScanModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Scan single struct
author := Author{}
err = adapter.NewSelect().
Model(&author).
Table("authors").
Where("id = ?", 1).
ScanModel(ctx) // ScanModel automatically uses the model set with Model()
if err != nil {
return err
}
// Example 2: Scan slice of structs
authors := []*Author{}
err = adapter.NewSelect().
Model(&authors).
Table("authors").
Where("active = ?", true).
Limit(10).
ScanModel(ctx)
return err
}
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading
func ExampleCompleteWorkflow() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
adapter.EnableQueryDebug() // Enable query logging
ctx := context.Background()
// Step 1: Create an author
author := &Author{
Name: "John Doe",
Email: "john@example.com",
}
result, err := adapter.NewInsert().
Table("authors").
Value("name", author.Name).
Value("email", author.Email).
Returning("id").
Exec(ctx)
if err != nil {
return err
}
_ = result
// Step 2: Load author with all their posts
var loadedAuthor Author
err = adapter.NewSelect().
Model(&loadedAuthor).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Order("created_at DESC").Limit(5)
}).
Where("id = ?", 1).
ScanModel(ctx)
if err != nil {
return err
}
// Step 3: Update author name
_, err = adapter.NewUpdate().
Table("authors").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
return err
}

View File

@@ -0,0 +1,629 @@
package database
import (
"context"
"database/sql"
"reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
}
func (u TestUser) TableName() string {
return "users"
}
type TestPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
User *TestUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
}
func (p TestPost) TableName() string {
return "posts"
}
type TestComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
}
func (c TestComment) TableName() string {
return "comments"
}
// TestNewPgSQLAdapter tests adapter creation
func TestNewPgSQLAdapter(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
assert.NotNil(t, adapter)
assert.Equal(t, db, adapter.db)
}
// TestPgSQLSelectQuery_BuildSQL tests SQL query building
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
tests := []struct {
name string
setup func(*PgSQLSelectQuery)
expected string
}{
{
name: "simple select",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
},
expected: "SELECT * FROM users",
},
{
name: "select with columns",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.columns = []string{"id", "name", "email"}
},
expected: "SELECT id, name, email FROM users",
},
{
name: "select with where",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.whereClauses = []string{"age > $1"}
q.args = []interface{}{18}
},
expected: "SELECT * FROM users WHERE (age > $1)",
},
{
name: "select with order and limit",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.orderBy = []string{"created_at DESC"}
q.limit = 10
q.offset = 5
},
expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
},
{
name: "select with join",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
},
expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
},
{
name: "select with group and having",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.groupBy = []string{"country"}
q.havingClauses = []string{"COUNT(*) > $1"}
q.args = []interface{}{5}
},
expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{
columns: []string{"*"},
}
tt.setup(q)
sql := q.buildSQL()
assert.Equal(t, tt.expected, sql)
})
}
}
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
tests := []struct {
name string
query string
argCount int
paramCounter int
expected string
}{
{
name: "single placeholder",
query: "age > ?",
argCount: 1,
paramCounter: 0,
expected: "age > $1",
},
{
name: "multiple placeholders",
query: "age > ? AND status = ?",
argCount: 2,
paramCounter: 0,
expected: "age > $1 AND status = $2",
},
{
name: "with existing counter",
query: "name = ?",
argCount: 1,
paramCounter: 5,
expected: "name = $6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
result := q.replacePlaceholders(tt.query, tt.argCount)
assert.Equal(t, tt.expected, result)
})
}
}
// TestPgSQLSelectQuery_Chaining tests method chaining
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
query := adapter.NewSelect().
Table("users").
Column("id", "name").
Where("age > ?", 18).
Order("name ASC").
Limit(10).
Offset(5)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
assert.Len(t, pgQuery.whereClauses, 1)
assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
assert.Equal(t, 10, pgQuery.limit)
assert.Equal(t, 5, pgQuery.offset)
}
// TestPgSQLSelectQuery_Model tests model setting
func TestPgSQLSelectQuery_Model(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
user := &TestUser{}
query := adapter.NewSelect().Model(user)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, user, pgQuery.model)
}
// TestScanRowsToStructSlice tests scanning rows into struct slice
func TestScanRowsToStructSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25).
AddRow(2, "Jane Smith", "jane@example.com", 30)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, "jane@example.com", users[1].Email)
assert.Equal(t, 30, users[1].Age)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice
func TestScanRowsToStructSlicePointers(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []*TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0])
assert.Equal(t, "John Doe", users[0].Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToSingleStruct tests scanning a single row
func TestScanRowsToSingleStruct(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var user TestUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", 1).
Scan(ctx, &user)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToMapSlice tests scanning into map slice
func TestScanRowsToMapSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email"}).
AddRow(1, "John Doe", "john@example.com").
AddRow(2, "Jane Smith", "jane@example.com")
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 2)
assert.Equal(t, int64(1), results[0]["id"])
assert.Equal(t, "John Doe", results[0]["name"])
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLInsertQuery_Exec tests insert query execution
func TestPgSQLInsertQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("John Doe", "john@example.com", 25).
WillReturnResult(sqlmock.NewResult(1, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLUpdateQuery_Exec tests update query execution
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Note: Args order is SET values first, then WHERE values
mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
WithArgs("Jane Doe", 1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLDeleteQuery_Exec tests delete query execution
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewDelete().
Table("users").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Count tests count query
func TestPgSQLSelectQuery_Count(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 42, count)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Exists tests exists query
func TestPgSQLSelectQuery_Exists(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_Transaction tests transaction handling
func TestPgSQLAdapter_Transaction(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_TransactionRollback tests transaction rollback
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
mock.ExpectRollback()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
assert.Error(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestBuildFieldMap tests field mapping construction
func TestBuildFieldMap(t *testing.T) {
userType := reflect.TypeOf(TestUser{})
fieldMap := buildFieldMap(userType, nil)
assert.NotEmpty(t, fieldMap)
// Check that fields are mapped
assert.Contains(t, fieldMap, "id")
assert.Contains(t, fieldMap, "name")
assert.Contains(t, fieldMap, "email")
assert.Contains(t, fieldMap, "age")
// Check field info
idInfo := fieldMap["id"]
assert.Equal(t, "ID", idInfo.Name)
}
// TestGetRelationMetadata tests relationship metadata extraction
func TestGetRelationMetadata(t *testing.T) {
q := &PgSQLSelectQuery{
model: &TestPost{},
}
// Test belongs-to relationship
meta := q.getRelationMetadata("User")
assert.NotNil(t, meta)
assert.Equal(t, "User", meta.fieldName)
// Test has-many relationship
meta = q.getRelationMetadata("Comments")
assert.NotNil(t, meta)
assert.Equal(t, "Comments", meta.fieldName)
}
// TestPreloadConfiguration tests preload configuration
func TestPreloadConfiguration(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
// Test Preload
query := adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
Preload("User")
pgQuery := query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.False(t, pgQuery.preloads[0].useJoin)
// Test PreloadRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
PreloadRelation("Comments")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "Comments", pgQuery.preloads[0].relation)
// Test JoinRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
JoinRelation("User")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.True(t, pgQuery.preloads[0].useJoin)
}
// TestScanModel tests ScanModel functionality
func TestScanModel(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
user := &TestUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("id = ?", 1).
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestRawSQL tests raw SQL execution
func TestRawSQL(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Test Exec
mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
require.NoError(t, err)
// Test Query
rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)
var results []map[string]interface{}
err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.NoError(t, mock.ExpectationsWereMet())
}

View File

@@ -0,0 +1,132 @@
package database
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/require"
)
// TestHelper provides utilities for database testing
type TestHelper struct {
DB *sql.DB
Adapter *PgSQLAdapter
t *testing.T
}
// NewTestHelper creates a new test helper
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
return &TestHelper{
DB: db,
Adapter: NewPgSQLAdapter(db),
t: t,
}
}
// CleanupTables truncates all test tables
func (h *TestHelper) CleanupTables() {
ctx := context.Background()
tables := []string{"comments", "posts", "users"}
for _, table := range tables {
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
require.NoError(h.t, err)
}
}
// InsertUser inserts a test user and returns the ID
func (h *TestHelper) InsertUser(name, email string, age int) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("users").
Value("name", name).
Value("email", email).
Value("age", age).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertPost inserts a test post and returns the ID
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("posts").
Value("user_id", userID).
Value("title", title).
Value("content", content).
Value("published", published).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertComment inserts a test comment and returns the ID
func (h *TestHelper) InsertComment(postID int, content string) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("comments").
Value("post_id", postID).
Value("content", content).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// AssertUserExists checks if a user exists by email
func (h *TestHelper) AssertUserExists(email string) {
ctx := context.Background()
exists, err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Exists(ctx)
require.NoError(h.t, err)
require.True(h.t, exists, "User with email %s should exist", email)
}
// AssertUserCount asserts the number of users
func (h *TestHelper) AssertUserCount(expected int) {
ctx := context.Background()
count, err := h.Adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(h.t, err)
require.Equal(h.t, expected, count)
}
// GetUserByEmail retrieves a user by email
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
ctx := context.Background()
var results []map[string]interface{}
err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Scan(ctx, &results)
require.NoError(h.t, err)
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
return results[0]
}
// BeginTestTransaction starts a transaction for testing
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
ctx := context.Background()
tx, err := h.DB.BeginTx(ctx, nil)
require.NoError(h.t, err)
adapter := &PgSQLTxAdapter{tx: tx}
cleanup := func() {
tx.Rollback()
}
return adapter, cleanup
}

View File

@@ -430,7 +430,45 @@ func extractTableAndColumn(cond string) (table string, column string) {
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if it contains a dot (qualified reference)
// Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(")
if openParenIdx >= 0 {
// There's a function call - find the FIRST dot after the opening paren
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
if dotIdx > 0 {
dotIdx += openParenIdx // Adjust to absolute position
// Extract table name (between paren and dot)
// Find the last opening paren before this dot
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
table = columnRef[lastOpenParen+1 : dotIdx]
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
columnStart := dotIdx + 1
columnEnd := len(columnRef)
for i := columnStart; i < len(columnRef); i++ {
ch := columnRef[i]
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
columnEnd = i
break
}
}
column = columnRef[columnStart:columnEnd]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
}
// No function call - check if it contains a dot (qualified reference)
// Use LastIndex to handle schema.table.column properly
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]

View File

@@ -286,6 +286,48 @@ func TestExtractTableAndColumn(t *testing.T) {
expectedTable: "",
expectedCol: "",
},
{
name: "function call with table.column - ifblnk",
input: "ifblnk(users.status,0) in (1,2,3,4)",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function call with table.column - coalesce",
input: "coalesce(users.age, 0) = 25",
expectedTable: "users",
expectedCol: "age",
},
{
name: "nested function calls",
input: "upper(trim(users.name)) = 'JOHN'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "function with multiple args and table.column",
input: "substring(users.email, 1, 5) = 'admin'",
expectedTable: "users",
expectedCol: "email",
},
{
name: "cast function with table.column",
input: "cast(orders.total as decimal) > 100",
expectedTable: "orders",
expectedCol: "total",
},
{
name: "complex nested functions",
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function with multiple table.column refs (extracts first)",
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
expectedTable: "users",
expectedCol: "created_at",
},
}
for _, tt := range tests {
@@ -352,6 +394,14 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "Function Call with correct table prefix - unchanged",
where: "ifblnk(users.status,0) in (1,2,3,4)",
tableName: "users",
options: nil,
expected: "ifblnk(users.status,0) in (1,2,3,4)",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",

View File

@@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Validate and filter columns in options (log warnings for invalid columns)
validator := common.NewColumnValidator(model)
options = filterExtendedOptions(validator, options)
options = h.filterExtendedOptions(validator, options, model)
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
@@ -2241,7 +2241,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
}
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions {
filtered := options
// Filter base RequestOptions
@@ -2265,12 +2265,30 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
// No filtering needed for ComputedQL keys
filtered.ComputedQL = options.ComputedQL
// Filter Expand columns
// Filter Expand columns using the expand relation's model
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
for _, expand := range options.Expand {
filteredExpand := expand
// Don't validate relation name, only columns
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
// Get the relationship info for this expand relation
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
if relInfo != nil && relInfo.relatedModel != nil {
// Create a validator for the related model
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
// Filter columns using the related model's validator
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
} else {
// If we can't find the relationship, log a warning and skip column filtering
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
// Keep the columns as-is if we can't validate them
filteredExpand.Columns = expand.Columns
}
filteredExpands = append(filteredExpands, filteredExpand)
}
filtered.Expand = filteredExpands

View File

@@ -111,7 +111,7 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
}
@@ -145,7 +145,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
}
@@ -297,7 +297,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
_ = a.db.QueryRowContext(ctx, query, sessionToken, userJSON).Scan(&success, &errorMsg, &updatedUserJSON)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
}
// RefreshToken implements Refreshable interface