mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 15:54:26 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c50eeac5bf | ||
|
|
6d88f2668a | ||
|
|
8a9423df6d | ||
|
|
4cc943b9d3 | ||
|
|
68dee78a34 | ||
|
|
efb9e5d9d5 | ||
|
|
490ae37c6d | ||
|
|
99307e31e6 | ||
|
|
e3f7869c6d | ||
|
|
c696d502c5 | ||
|
|
4ed1fba6ad |
82
.github/workflows/make_tag.yml
vendored
Normal file
82
.github/workflows/make_tag.yml
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
# This workflow will build a golang project
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
|
||||
|
||||
name: Create Go Release (Tag Versioning)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
semver:
|
||||
description: "New Version"
|
||||
required: true
|
||||
default: "patch"
|
||||
type: choice
|
||||
options:
|
||||
- patch
|
||||
- minor
|
||||
- major
|
||||
|
||||
jobs:
|
||||
tag_and_commit:
|
||||
name: "Tag and Commit ${{ github.event.inputs.semver }}"
|
||||
runs-on: linux
|
||||
permissions:
|
||||
contents: write # 'write' access to repository contents
|
||||
pull-requests: write # 'write' access to pull requests
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Git
|
||||
run: |
|
||||
git config --global user.name "Hein"
|
||||
git config --global user.email "hein.puth@gmail.com"
|
||||
|
||||
- name: Fetch latest tag
|
||||
id: latest_tag
|
||||
run: |
|
||||
git fetch --tags
|
||||
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
|
||||
echo "::set-output name=tag::$latest_tag"
|
||||
|
||||
- name: Determine new tag version
|
||||
id: new_tag
|
||||
run: |
|
||||
current_tag=${{ steps.latest_tag.outputs.tag }}
|
||||
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
|
||||
IFS='.' read -r -a version_parts <<< "$version"
|
||||
major=${version_parts[0]}
|
||||
minor=${version_parts[1]}
|
||||
patch=${version_parts[2]}
|
||||
case "${{ github.event.inputs.semver }}" in
|
||||
"patch")
|
||||
((patch++))
|
||||
;;
|
||||
"minor")
|
||||
((minor++))
|
||||
patch=0
|
||||
;;
|
||||
"release")
|
||||
((major++))
|
||||
minor=0
|
||||
patch=0
|
||||
;;
|
||||
*)
|
||||
echo "Invalid semver input"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
new_tag="v$major.$minor.$patch"
|
||||
echo "::set-output name=tag::$new_tag"
|
||||
|
||||
- name: Create tag
|
||||
run: |
|
||||
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
|
||||
|
||||
- name: Push changes
|
||||
uses: ad-m/github-push-action@master
|
||||
with:
|
||||
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
|
||||
force: true
|
||||
tags: true
|
||||
14
.vscode/tasks.json
vendored
14
.vscode/tasks.json
vendored
@@ -230,7 +230,17 @@
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: lint workspace (fix)",
|
||||
"command": "golangci-lint run --timeout=5m --fix",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
@@ -275,4 +285,4 @@
|
||||
"command": "sh ${workspaceFolder}/make_release.sh"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
6
go.mod
6
go.mod
@@ -5,12 +5,15 @@ go 1.24.0
|
||||
toolchain go1.24.6
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
|
||||
github.com/getsentry/sentry-go v0.40.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/spf13/viper v1.21.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
@@ -30,7 +33,6 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -38,7 +40,6 @@ require (
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/getsentry/sentry-go v0.40.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
@@ -66,7 +67,6 @@ require (
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/spf13/viper v1.21.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
|
||||
8
go.sum
8
go.sum
@@ -19,6 +19,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
|
||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||
@@ -27,6 +29,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
|
||||
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -77,6 +81,10 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
|
||||
@@ -34,6 +34,63 @@ func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent)
|
||||
}
|
||||
}
|
||||
|
||||
// debugScanIntoStruct attempts to scan rows into a struct with detailed field-level logging
|
||||
// This helps identify which specific field is causing scanning issues
|
||||
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct && v.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be pointer to struct or slice")
|
||||
}
|
||||
|
||||
// Log the type being scanned into
|
||||
typeName := v.Type().String()
|
||||
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
|
||||
|
||||
// Handle slice types - inspect the element type
|
||||
var structType reflect.Type
|
||||
if v.Kind() == reflect.Slice {
|
||||
elemType := v.Type().Elem()
|
||||
logger.Debug(" Slice element type: %s", elemType)
|
||||
|
||||
// If slice of pointers, get the underlying type
|
||||
if elemType.Kind() == reflect.Ptr {
|
||||
structType = elemType.Elem()
|
||||
} else {
|
||||
structType = elemType
|
||||
}
|
||||
} else if v.Kind() == reflect.Struct {
|
||||
structType = v.Type()
|
||||
}
|
||||
|
||||
// If we have a struct type, log all its fields
|
||||
if structType != nil && structType.Kind() == reflect.Struct {
|
||||
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
|
||||
for i := 0; i < structType.NumField(); i++ {
|
||||
field := structType.Field(i)
|
||||
|
||||
// Log embedded fields specially
|
||||
if field.Anonymous {
|
||||
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
|
||||
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
|
||||
} else {
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag == "" {
|
||||
bunTag = "(no tag)"
|
||||
}
|
||||
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
|
||||
i, field.Name, field.Type, field.Type.Kind(), bunTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
// This demonstrates how the abstraction works with different ORMs
|
||||
type BunAdapter struct {
|
||||
@@ -52,6 +109,14 @@ func (b *BunAdapter) EnableQueryDebug() {
|
||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||
}
|
||||
|
||||
// EnableDetailedScanDebug enables verbose logging of scan operations
|
||||
// WARNING: This generates a LOT of log output. Use only for debugging specific issues.
|
||||
func (b *BunAdapter) EnableDetailedScanDebug() {
|
||||
logger.Info("Detailed scan debugging enabled - will log all field scanning operations")
|
||||
// This is a flag that can be checked in scan operations
|
||||
// Implementation would require modifying the scan logic
|
||||
}
|
||||
|
||||
// DisableQueryDebug removes all query hooks
|
||||
func (b *BunAdapter) DisableQueryDebug() {
|
||||
// Create a new DB without hooks
|
||||
@@ -676,6 +741,31 @@ func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error)
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// Enhanced panic recovery with model information
|
||||
model := b.query.GetModel()
|
||||
var modelInfo string
|
||||
if model != nil && model.Value() != nil {
|
||||
modelValue := model.Value()
|
||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||
|
||||
// Try to get the model's underlying struct type
|
||||
v := reflect.ValueOf(modelValue)
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if v.Kind() == reflect.Slice {
|
||||
if v.Type().Elem().Kind() == reflect.Ptr {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
|
||||
} else {
|
||||
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
|
||||
}
|
||||
} else if v.Kind() == reflect.Struct {
|
||||
modelInfo += fmt.Sprintf(", Struct: %s", v.Type().Name())
|
||||
}
|
||||
}
|
||||
|
||||
sqlStr := b.query.String()
|
||||
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
@@ -683,6 +773,17 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
return fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
// Optional: Enable detailed field-level debugging (set to true to debug)
|
||||
const enableDetailedDebug = true
|
||||
if enableDetailedDebug {
|
||||
model := b.query.GetModel()
|
||||
if model != nil && model.Value() != nil {
|
||||
if err := debugScanIntoStruct(nil, model.Value()); err != nil {
|
||||
logger.Warn("Debug scan inspection failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
|
||||
1355
pkg/common/adapters/database/pgsql.go
Normal file
1355
pkg/common/adapters/database/pgsql.go
Normal file
File diff suppressed because it is too large
Load Diff
176
pkg/common/adapters/database/pgsql_example.go
Normal file
176
pkg/common/adapters/database/pgsql_example.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Example demonstrates how to use the PgSQL adapter
|
||||
func ExamplePgSQLAdapter() error {
|
||||
// Connect to PostgreSQL database
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create the PgSQL adapter
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
|
||||
// Enable query debugging (optional)
|
||||
adapter.EnableQueryDebug()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Simple SELECT query
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("age > ?", 18).
|
||||
Order("created_at DESC").
|
||||
Limit(10).
|
||||
Scan(ctx, &results)
|
||||
if err != nil {
|
||||
return fmt.Errorf("select failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 2: INSERT query
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Returning("id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 3: UPDATE query
|
||||
result, err = adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 4: DELETE query
|
||||
result, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("age < ?", 18).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
|
||||
|
||||
// Example 5: Using transactions
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Insert a new user
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Transaction User").
|
||||
Value("email", "tx@example.com").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update another user
|
||||
_, err = tx.NewUpdate().
|
||||
Table("users").
|
||||
Set("verified", true).
|
||||
Where("email = ?", "tx@example.com").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Both operations succeed or both rollback
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("transaction failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 6: JOIN query
|
||||
err = adapter.NewSelect().
|
||||
Table("users u").
|
||||
Column("u.id", "u.name", "p.title as post_title").
|
||||
LeftJoin("posts p ON p.user_id = u.id").
|
||||
Where("u.active = ?", true).
|
||||
Scan(ctx, &results)
|
||||
if err != nil {
|
||||
return fmt.Errorf("join query failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 7: Aggregation query
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("active = ?", true).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count failed: %w", err)
|
||||
}
|
||||
fmt.Printf("Active users: %d\n", count)
|
||||
|
||||
// Example 8: Raw SQL execution
|
||||
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
||||
if err != nil {
|
||||
return fmt.Errorf("raw exec failed: %w", err)
|
||||
}
|
||||
|
||||
// Example 9: Raw SQL query
|
||||
var users []map[string]interface{}
|
||||
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
|
||||
if err != nil {
|
||||
return fmt.Errorf("raw query failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// User is an example model
|
||||
type User struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
// TableName implements common.TableNameProvider
|
||||
func (u User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// ExampleWithModel demonstrates using models with the PgSQL adapter
|
||||
func ExampleWithModel() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Use model with adapter
|
||||
user := User{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&user).
|
||||
Where("id = ?", 1).
|
||||
Scan(ctx, &user)
|
||||
|
||||
return err
|
||||
}
|
||||
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
@@ -0,0 +1,526 @@
|
||||
// +build integration
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
// Integration test models
|
||||
type IntegrationUser struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Age int `db:"age"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
|
||||
}
|
||||
|
||||
func (u IntegrationUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type IntegrationPost struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
UserID int `db:"user_id"`
|
||||
Published bool `db:"published"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p IntegrationPost) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
type IntegrationComment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
|
||||
}
|
||||
|
||||
func (c IntegrationComment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// setupTestDB creates a PostgreSQL container and returns the connection
|
||||
func setupTestDB(t *testing.T) (*sql.DB, func()) {
|
||||
ctx := context.Background()
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: "postgres:15-alpine",
|
||||
ExposedPorts: []string{"5432/tcp"},
|
||||
Env: map[string]string{
|
||||
"POSTGRES_USER": "testuser",
|
||||
"POSTGRES_PASSWORD": "testpass",
|
||||
"POSTGRES_DB": "testdb",
|
||||
},
|
||||
WaitingFor: wait.ForLog("database system is ready to accept connections").
|
||||
WithOccurrence(2).
|
||||
WithStartupTimeout(60 * time.Second),
|
||||
}
|
||||
|
||||
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
host, err := postgres.Host(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
port, err := postgres.MappedPort(ctx, "5432")
|
||||
require.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
|
||||
host, port.Port())
|
||||
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for database to be ready
|
||||
err = db.Ping()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create schema
|
||||
createSchema(t, db)
|
||||
|
||||
cleanup := func() {
|
||||
db.Close()
|
||||
postgres.Terminate(ctx)
|
||||
}
|
||||
|
||||
return db, cleanup
|
||||
}
|
||||
|
||||
// createSchema creates test tables
|
||||
func createSchema(t *testing.T, db *sql.DB) {
|
||||
schema := `
|
||||
DROP TABLE IF EXISTS comments CASCADE;
|
||||
DROP TABLE IF EXISTS posts CASCADE;
|
||||
DROP TABLE IF EXISTS users CASCADE;
|
||||
|
||||
CREATE TABLE users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
email VARCHAR(255) UNIQUE NOT NULL,
|
||||
age INT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE posts (
|
||||
id SERIAL PRIMARY KEY,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
published BOOLEAN DEFAULT false,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE comments (
|
||||
id SERIAL PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`
|
||||
|
||||
_, err := db.Exec(schema)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// TestIntegration_BasicCRUD tests basic CRUD operations
|
||||
func TestIntegration_BasicCRUD(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// CREATE
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// READ
|
||||
var users []IntegrationUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "john@example.com").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
assert.Equal(t, 25, users[0].Age)
|
||||
|
||||
userID := users[0].ID
|
||||
|
||||
// UPDATE
|
||||
result, err = adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("age", 26).
|
||||
Where("id = ?", userID).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify update
|
||||
var updatedUser IntegrationUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Scan(ctx, &updatedUser)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 26, updatedUser.Age)
|
||||
|
||||
// DELETE
|
||||
result, err = adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify delete
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", userID).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
// TestIntegration_ScanModel tests ScanModel functionality
|
||||
func TestIntegration_ScanModel(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Insert test data
|
||||
_, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Jane Smith").
|
||||
Value("email", "jane@example.com").
|
||||
Value("age", 30).
|
||||
Exec(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test single struct scan
|
||||
user := &IntegrationUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(user).
|
||||
Table("users").
|
||||
Where("email = ?", "jane@example.com").
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Jane Smith", user.Name)
|
||||
assert.Equal(t, 30, user.Age)
|
||||
|
||||
// Test slice scan
|
||||
users := []*IntegrationUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&users).
|
||||
Table("users").
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
}
|
||||
|
||||
// TestIntegration_Transaction tests transaction handling
|
||||
func TestIntegration_Transaction(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Successful transaction
|
||||
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Alice").
|
||||
Value("email", "alice@example.com").
|
||||
Value("age", 28).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Bob").
|
||||
Value("email", "bob@example.com").
|
||||
Value("age", 32).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify both records exist
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// Failed transaction (should rollback)
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "Charlie").
|
||||
Value("email", "charlie@example.com").
|
||||
Value("age", 35).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Intentional error - duplicate email
|
||||
_, err = tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "David").
|
||||
Value("email", "alice@example.com"). // Duplicate
|
||||
Value("age", 40).
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
|
||||
// Verify rollback - count should still be 2
|
||||
count, err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
// TestIntegration_Preload tests basic preload functionality
|
||||
func TestIntegration_Preload(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
|
||||
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
|
||||
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
|
||||
|
||||
// Test Preload
|
||||
var users []*IntegrationUser
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationUser{}).
|
||||
Table("users").
|
||||
Preload("Posts").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.NotNil(t, users[0].Posts)
|
||||
assert.Len(t, users[0].Posts, 2)
|
||||
}
|
||||
|
||||
// TestIntegration_PreloadRelation tests smart PreloadRelation
|
||||
func TestIntegration_PreloadRelation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
|
||||
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
|
||||
createTestComment(t, adapter, ctx, postID, "Great post!")
|
||||
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
|
||||
|
||||
// Test PreloadRelation with belongs-to (should use JOIN)
|
||||
var posts []*IntegrationPost
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("User").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
// Note: JOIN preloading needs proper column selection to work
|
||||
// For now, we test that it doesn't error
|
||||
|
||||
// Test PreloadRelation with has-many (should use subquery)
|
||||
posts = []*IntegrationPost{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Comments").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
if posts[0].Comments != nil {
|
||||
assert.Len(t, posts[0].Comments, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestIntegration_JoinRelation tests explicit JoinRelation
|
||||
func TestIntegration_JoinRelation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
|
||||
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
|
||||
|
||||
// Test JoinRelation
|
||||
var posts []*IntegrationPost
|
||||
err := adapter.NewSelect().
|
||||
Model(&IntegrationPost{}).
|
||||
Table("posts").
|
||||
JoinRelation("User").
|
||||
Scan(ctx, &posts)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, posts, 1)
|
||||
}
|
||||
|
||||
// TestIntegration_ComplexQuery tests complex queries
|
||||
func TestIntegration_ComplexQuery(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
|
||||
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
|
||||
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
|
||||
|
||||
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
|
||||
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
|
||||
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
|
||||
|
||||
// Complex query with joins, where, order, limit
|
||||
var results []map[string]interface{}
|
||||
err := adapter.NewSelect().
|
||||
Table("posts p").
|
||||
Column("p.title", "u.name as author_name", "u.age as author_age").
|
||||
LeftJoin("users u ON u.id = p.user_id").
|
||||
Where("p.published = ?", true).
|
||||
WhereOr("u.age > ?", 25).
|
||||
Order("u.age DESC").
|
||||
Limit(2).
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.LessOrEqual(t, len(results), 2)
|
||||
}
|
||||
|
||||
// TestIntegration_Aggregation tests aggregation queries
|
||||
func TestIntegration_Aggregation(t *testing.T) {
|
||||
db, cleanup := setupTestDB(t)
|
||||
defer cleanup()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
|
||||
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
|
||||
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
|
||||
|
||||
// Test Count
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("age >= ?", 25).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
// Test Exists
|
||||
exists, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "user1@example.com").
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// Test Group By with aggregation
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Column("age", "COUNT(*) as count").
|
||||
Group("age").
|
||||
Having("COUNT(*) > ?", 0).
|
||||
Order("age ASC").
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 3)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
|
||||
var userID int
|
||||
err := adapter.Query(ctx, &userID,
|
||||
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
|
||||
name, email, age)
|
||||
require.NoError(t, err)
|
||||
return userID
|
||||
}
|
||||
|
||||
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
|
||||
var postID int
|
||||
err := adapter.Query(ctx, &postID,
|
||||
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
|
||||
title, content, userID, published)
|
||||
require.NoError(t, err)
|
||||
return postID
|
||||
}
|
||||
|
||||
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
|
||||
var commentID int
|
||||
err := adapter.Query(ctx, &commentID,
|
||||
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
|
||||
content, postID)
|
||||
require.NoError(t, err)
|
||||
return commentID
|
||||
}
|
||||
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Example models for demonstrating preload functionality
|
||||
|
||||
// Author model - has many Posts
|
||||
type Author struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Posts []*Post `bun:"rel:has-many,join:id=author_id"`
|
||||
}
|
||||
|
||||
func (a Author) TableName() string {
|
||||
return "authors"
|
||||
}
|
||||
|
||||
// Post model - belongs to Author, has many Comments
|
||||
type Post struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
AuthorID int `db:"author_id"`
|
||||
Author *Author `bun:"rel:belongs-to,join:author_id=id"`
|
||||
Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p Post) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
// Comment model - belongs to Post
|
||||
type Comment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
Post *Post `bun:"rel:belongs-to,join:post_id=id"`
|
||||
}
|
||||
|
||||
func (c Comment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// ExamplePreload demonstrates the Preload functionality
|
||||
func ExamplePreload() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Simple Preload (uses subquery for has-many)
|
||||
var authors []*Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
Preload("Posts"). // Load all posts for each author
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now authors[i].Posts will be populated with their posts
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection
|
||||
func ExamplePreloadRelation() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: PreloadRelation auto-detects has-many (uses subquery)
|
||||
var authors []*Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("published = ?", true).Order("created_at DESC")
|
||||
}).
|
||||
Where("active = ?", true).
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
|
||||
var posts []*Post
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Author"). // Will use JOIN because it's belongs-to
|
||||
Scan(ctx, &posts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 3: Nested preloads
|
||||
err = adapter.NewSelect().
|
||||
Model(&Author{}).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
// First load posts, then preload comments for each post
|
||||
return q.Limit(10)
|
||||
}).
|
||||
Scan(ctx, &authors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Manually load nested relationships (two-level preloading)
|
||||
for _, author := range authors {
|
||||
if author.Posts != nil {
|
||||
for _, post := range author.Posts {
|
||||
var comments []*Comment
|
||||
err := adapter.NewSelect().
|
||||
Table("comments").
|
||||
Where("post_id = ?", post.ID).
|
||||
Scan(ctx, &comments)
|
||||
if err == nil {
|
||||
post.Comments = comments
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleJoinRelation demonstrates explicit JOIN loading
|
||||
func ExampleJoinRelation() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Force JOIN for belongs-to relationship
|
||||
var posts []*Post
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts").
|
||||
JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Where("active = ?", true)
|
||||
}).
|
||||
Scan(ctx, &posts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: Multiple JOINs
|
||||
err = adapter.NewSelect().
|
||||
Model(&Post{}).
|
||||
Table("posts p").
|
||||
Column("p.*", "a.name as author_name", "a.email as author_email").
|
||||
LeftJoin("authors a ON a.id = p.author_id").
|
||||
Where("p.published = ?", true).
|
||||
Scan(ctx, &posts)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExampleScanModel demonstrates ScanModel with struct destinations
|
||||
func ExampleScanModel() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Example 1: Scan single struct
|
||||
author := Author{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&author).
|
||||
Table("authors").
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx) // ScanModel automatically uses the model set with Model()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Example 2: Scan slice of structs
|
||||
authors := []*Author{}
|
||||
err = adapter.NewSelect().
|
||||
Model(&authors).
|
||||
Table("authors").
|
||||
Where("active = ?", true).
|
||||
Limit(10).
|
||||
ScanModel(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading
|
||||
func ExampleCompleteWorkflow() error {
|
||||
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
adapter.EnableQueryDebug() // Enable query logging
|
||||
ctx := context.Background()
|
||||
|
||||
// Step 1: Create an author
|
||||
author := &Author{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
}
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Table("authors").
|
||||
Value("name", author.Name).
|
||||
Value("email", author.Email).
|
||||
Returning("id").
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_ = result
|
||||
|
||||
// Step 2: Load author with all their posts
|
||||
var loadedAuthor Author
|
||||
err = adapter.NewSelect().
|
||||
Model(&loadedAuthor).
|
||||
Table("authors").
|
||||
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
|
||||
return q.Order("created_at DESC").Limit(5)
|
||||
}).
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Step 3: Update author name
|
||||
_, err = adapter.NewUpdate().
|
||||
Table("authors").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
629
pkg/common/adapters/database/pgsql_test.go
Normal file
629
pkg/common/adapters/database/pgsql_test.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Test models
|
||||
type TestUser struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Email string `db:"email"`
|
||||
Age int `db:"age"`
|
||||
}
|
||||
|
||||
func (u TestUser) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID int `db:"id"`
|
||||
Title string `db:"title"`
|
||||
Content string `db:"content"`
|
||||
UserID int `db:"user_id"`
|
||||
User *TestUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||
Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
|
||||
}
|
||||
|
||||
func (p TestPost) TableName() string {
|
||||
return "posts"
|
||||
}
|
||||
|
||||
type TestComment struct {
|
||||
ID int `db:"id"`
|
||||
Content string `db:"content"`
|
||||
PostID int `db:"post_id"`
|
||||
}
|
||||
|
||||
func (c TestComment) TableName() string {
|
||||
return "comments"
|
||||
}
|
||||
|
||||
// TestNewPgSQLAdapter tests adapter creation
|
||||
func TestNewPgSQLAdapter(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
assert.NotNil(t, adapter)
|
||||
assert.Equal(t, db, adapter.db)
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_BuildSQL tests SQL query building
|
||||
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*PgSQLSelectQuery)
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
},
|
||||
expected: "SELECT * FROM users",
|
||||
},
|
||||
{
|
||||
name: "select with columns",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.columns = []string{"id", "name", "email"}
|
||||
},
|
||||
expected: "SELECT id, name, email FROM users",
|
||||
},
|
||||
{
|
||||
name: "select with where",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.whereClauses = []string{"age > $1"}
|
||||
q.args = []interface{}{18}
|
||||
},
|
||||
expected: "SELECT * FROM users WHERE (age > $1)",
|
||||
},
|
||||
{
|
||||
name: "select with order and limit",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.orderBy = []string{"created_at DESC"}
|
||||
q.limit = 10
|
||||
q.offset = 5
|
||||
},
|
||||
expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
|
||||
},
|
||||
{
|
||||
name: "select with join",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
|
||||
},
|
||||
expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
|
||||
},
|
||||
{
|
||||
name: "select with group and having",
|
||||
setup: func(q *PgSQLSelectQuery) {
|
||||
q.tableName = "users"
|
||||
q.groupBy = []string{"country"}
|
||||
q.havingClauses = []string{"COUNT(*) > $1"}
|
||||
q.args = []interface{}{5}
|
||||
},
|
||||
expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{
|
||||
columns: []string{"*"},
|
||||
}
|
||||
tt.setup(q)
|
||||
sql := q.buildSQL()
|
||||
assert.Equal(t, tt.expected, sql)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement
|
||||
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
argCount int
|
||||
paramCounter int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single placeholder",
|
||||
query: "age > ?",
|
||||
argCount: 1,
|
||||
paramCounter: 0,
|
||||
expected: "age > $1",
|
||||
},
|
||||
{
|
||||
name: "multiple placeholders",
|
||||
query: "age > ? AND status = ?",
|
||||
argCount: 2,
|
||||
paramCounter: 0,
|
||||
expected: "age > $1 AND status = $2",
|
||||
},
|
||||
{
|
||||
name: "with existing counter",
|
||||
query: "name = ?",
|
||||
argCount: 1,
|
||||
paramCounter: 5,
|
||||
expected: "name = $6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
|
||||
result := q.replacePlaceholders(tt.query, tt.argCount)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Chaining tests method chaining
|
||||
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
query := adapter.NewSelect().
|
||||
Table("users").
|
||||
Column("id", "name").
|
||||
Where("age > ?", 18).
|
||||
Order("name ASC").
|
||||
Limit(10).
|
||||
Offset(5)
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Equal(t, "users", pgQuery.tableName)
|
||||
assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
|
||||
assert.Len(t, pgQuery.whereClauses, 1)
|
||||
assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
|
||||
assert.Equal(t, 10, pgQuery.limit)
|
||||
assert.Equal(t, 5, pgQuery.offset)
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Model tests model setting
|
||||
func TestPgSQLSelectQuery_Model(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
user := &TestUser{}
|
||||
query := adapter.NewSelect().Model(user)
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Equal(t, "users", pgQuery.tableName)
|
||||
assert.Equal(t, user, pgQuery.model)
|
||||
}
|
||||
|
||||
// TestScanRowsToStructSlice tests scanning rows into struct slice
|
||||
func TestScanRowsToStructSlice(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25).
|
||||
AddRow(2, "Jane Smith", "jane@example.com", 30)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var users []TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 2)
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
assert.Equal(t, "jane@example.com", users[1].Email)
|
||||
assert.Equal(t, 30, users[1].Age)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice
|
||||
func TestScanRowsToStructSlicePointers(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var users []*TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &users)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, users, 1)
|
||||
assert.NotNil(t, users[0])
|
||||
assert.Equal(t, "John Doe", users[0].Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToSingleStruct tests scanning a single row
|
||||
func TestScanRowsToSingleStruct(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var user TestUser
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Scan(ctx, &user)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, user.ID)
|
||||
assert.Equal(t, "John Doe", user.Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestScanRowsToMapSlice tests scanning into map slice
|
||||
func TestScanRowsToMapSlice(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email"}).
|
||||
AddRow(1, "John Doe", "john@example.com").
|
||||
AddRow(2, "Jane Smith", "jane@example.com")
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
var results []map[string]interface{}
|
||||
err = adapter.NewSelect().
|
||||
Table("users").
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 2)
|
||||
assert.Equal(t, int64(1), results[0]["id"])
|
||||
assert.Equal(t, "John Doe", results[0]["name"])
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLInsertQuery_Exec tests insert query execution
|
||||
func TestPgSQLInsertQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectExec("INSERT INTO users").
|
||||
WithArgs("John Doe", "john@example.com", 25).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John Doe").
|
||||
Value("email", "john@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLUpdateQuery_Exec tests update query execution
|
||||
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Note: Args order is SET values first, then WHERE values
|
||||
mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
|
||||
WithArgs("Jane Doe", 1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewUpdate().
|
||||
Table("users").
|
||||
Set("name", "Jane Doe").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLDeleteQuery_Exec tests delete query execution
|
||||
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
|
||||
WithArgs(1).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
result, err := adapter.NewDelete().
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Count tests count query
|
||||
func TestPgSQLSelectQuery_Count(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 42, count)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLSelectQuery_Exists tests exists query
|
||||
func TestPgSQLSelectQuery_Exists(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
exists, err := adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", "john@example.com").
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLAdapter_Transaction tests transaction handling
|
||||
func TestPgSQLAdapter_Transaction(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John").
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestPgSQLAdapter_TransactionRollback tests transaction rollback
|
||||
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
|
||||
mock.ExpectRollback()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
_, err := tx.NewInsert().
|
||||
Table("users").
|
||||
Value("name", "John").
|
||||
Exec(ctx)
|
||||
return err
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestBuildFieldMap tests field mapping construction
|
||||
func TestBuildFieldMap(t *testing.T) {
|
||||
userType := reflect.TypeOf(TestUser{})
|
||||
fieldMap := buildFieldMap(userType, nil)
|
||||
|
||||
assert.NotEmpty(t, fieldMap)
|
||||
|
||||
// Check that fields are mapped
|
||||
assert.Contains(t, fieldMap, "id")
|
||||
assert.Contains(t, fieldMap, "name")
|
||||
assert.Contains(t, fieldMap, "email")
|
||||
assert.Contains(t, fieldMap, "age")
|
||||
|
||||
// Check field info
|
||||
idInfo := fieldMap["id"]
|
||||
assert.Equal(t, "ID", idInfo.Name)
|
||||
}
|
||||
|
||||
// TestGetRelationMetadata tests relationship metadata extraction
|
||||
func TestGetRelationMetadata(t *testing.T) {
|
||||
q := &PgSQLSelectQuery{
|
||||
model: &TestPost{},
|
||||
}
|
||||
|
||||
// Test belongs-to relationship
|
||||
meta := q.getRelationMetadata("User")
|
||||
assert.NotNil(t, meta)
|
||||
assert.Equal(t, "User", meta.fieldName)
|
||||
|
||||
// Test has-many relationship
|
||||
meta = q.getRelationMetadata("Comments")
|
||||
assert.NotNil(t, meta)
|
||||
assert.Equal(t, "Comments", meta.fieldName)
|
||||
}
|
||||
|
||||
// TestPreloadConfiguration tests preload configuration
|
||||
func TestPreloadConfiguration(t *testing.T) {
|
||||
db, _, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
|
||||
// Test Preload
|
||||
query := adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
Preload("User")
|
||||
|
||||
pgQuery := query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "User", pgQuery.preloads[0].relation)
|
||||
assert.False(t, pgQuery.preloads[0].useJoin)
|
||||
|
||||
// Test PreloadRelation
|
||||
query = adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
PreloadRelation("Comments")
|
||||
|
||||
pgQuery = query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "Comments", pgQuery.preloads[0].relation)
|
||||
|
||||
// Test JoinRelation
|
||||
query = adapter.NewSelect().
|
||||
Model(&TestPost{}).
|
||||
Table("posts").
|
||||
JoinRelation("User")
|
||||
|
||||
pgQuery = query.(*PgSQLSelectQuery)
|
||||
assert.Len(t, pgQuery.preloads, 1)
|
||||
assert.Equal(t, "User", pgQuery.preloads[0].relation)
|
||||
assert.True(t, pgQuery.preloads[0].useJoin)
|
||||
}
|
||||
|
||||
// TestScanModel tests ScanModel functionality
|
||||
func TestScanModel(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
|
||||
AddRow(1, "John Doe", "john@example.com", 25)
|
||||
|
||||
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &TestUser{}
|
||||
err = adapter.NewSelect().
|
||||
Model(user).
|
||||
Table("users").
|
||||
Where("id = ?", 1).
|
||||
ScanModel(ctx)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, user.ID)
|
||||
assert.Equal(t, "John Doe", user.Name)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
// TestRawSQL tests raw SQL execution
|
||||
func TestRawSQL(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
|
||||
// Test Exec
|
||||
mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
|
||||
adapter := NewPgSQLAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test Query
|
||||
rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
|
||||
mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)
|
||||
|
||||
var results []map[string]interface{}
|
||||
err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, 1)
|
||||
|
||||
assert.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
132
pkg/common/adapters/database/test_helpers.go
Normal file
132
pkg/common/adapters/database/test_helpers.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestHelper provides utilities for database testing
|
||||
type TestHelper struct {
|
||||
DB *sql.DB
|
||||
Adapter *PgSQLAdapter
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
// NewTestHelper creates a new test helper
|
||||
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
|
||||
return &TestHelper{
|
||||
DB: db,
|
||||
Adapter: NewPgSQLAdapter(db),
|
||||
t: t,
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupTables truncates all test tables
|
||||
func (h *TestHelper) CleanupTables() {
|
||||
ctx := context.Background()
|
||||
tables := []string{"comments", "posts", "users"}
|
||||
|
||||
for _, table := range tables {
|
||||
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
|
||||
require.NoError(h.t, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InsertUser inserts a test user and returns the ID
|
||||
func (h *TestHelper) InsertUser(name, email string, age int) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("users").
|
||||
Value("name", name).
|
||||
Value("email", email).
|
||||
Value("age", age).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// InsertPost inserts a test post and returns the ID
|
||||
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("posts").
|
||||
Value("user_id", userID).
|
||||
Value("title", title).
|
||||
Value("content", content).
|
||||
Value("published", published).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// InsertComment inserts a test comment and returns the ID
|
||||
func (h *TestHelper) InsertComment(postID int, content string) int {
|
||||
ctx := context.Background()
|
||||
result, err := h.Adapter.NewInsert().
|
||||
Table("comments").
|
||||
Value("post_id", postID).
|
||||
Value("content", content).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
id, _ := result.LastInsertId()
|
||||
return int(id)
|
||||
}
|
||||
|
||||
// AssertUserExists checks if a user exists by email
|
||||
func (h *TestHelper) AssertUserExists(email string) {
|
||||
ctx := context.Background()
|
||||
exists, err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", email).
|
||||
Exists(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.True(h.t, exists, "User with email %s should exist", email)
|
||||
}
|
||||
|
||||
// AssertUserCount asserts the number of users
|
||||
func (h *TestHelper) AssertUserCount(expected int) {
|
||||
ctx := context.Background()
|
||||
count, err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.Equal(h.t, expected, count)
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by email
|
||||
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
|
||||
ctx := context.Background()
|
||||
var results []map[string]interface{}
|
||||
err := h.Adapter.NewSelect().
|
||||
Table("users").
|
||||
Where("email = ?", email).
|
||||
Scan(ctx, &results)
|
||||
|
||||
require.NoError(h.t, err)
|
||||
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
|
||||
return results[0]
|
||||
}
|
||||
|
||||
// BeginTestTransaction starts a transaction for testing
|
||||
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
|
||||
ctx := context.Background()
|
||||
tx, err := h.DB.BeginTx(ctx, nil)
|
||||
require.NoError(h.t, err)
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx}
|
||||
cleanup := func() {
|
||||
tx.Rollback()
|
||||
}
|
||||
|
||||
return adapter, cleanup
|
||||
}
|
||||
@@ -393,6 +393,7 @@ func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
|
||||
// For example: "users.status = 'active'" returns ("users", "status")
|
||||
// Returns empty strings if no table prefix is found
|
||||
// This function is parenthesis-aware and will only look for operators outside of subqueries
|
||||
func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Common SQL operators to find the column reference
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
@@ -400,13 +401,20 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
var columnRef string
|
||||
|
||||
// Find the column reference (left side of the operator)
|
||||
// We need to find the first operator that appears OUTSIDE of parentheses
|
||||
minIdx := -1
|
||||
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:idx])
|
||||
break
|
||||
idx := findOperatorOutsideParentheses(cond, op)
|
||||
if idx > 0 && (minIdx == -1 || idx < minIdx) {
|
||||
minIdx = idx
|
||||
}
|
||||
}
|
||||
|
||||
if minIdx > 0 {
|
||||
columnRef = strings.TrimSpace(cond[:minIdx])
|
||||
}
|
||||
|
||||
// If no operator found, the whole condition might be the column reference
|
||||
if columnRef == "" {
|
||||
parts := strings.Fields(cond)
|
||||
@@ -422,7 +430,45 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
// Remove any quotes
|
||||
columnRef = strings.Trim(columnRef, "`\"'")
|
||||
|
||||
// Check if it contains a dot (qualified reference)
|
||||
// Check if there's a function call (contains opening parenthesis)
|
||||
openParenIdx := strings.Index(columnRef, "(")
|
||||
|
||||
if openParenIdx >= 0 {
|
||||
// There's a function call - find the FIRST dot after the opening paren
|
||||
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
|
||||
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
|
||||
if dotIdx > 0 {
|
||||
dotIdx += openParenIdx // Adjust to absolute position
|
||||
|
||||
// Extract table name (between paren and dot)
|
||||
// Find the last opening paren before this dot
|
||||
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
|
||||
table = columnRef[lastOpenParen+1 : dotIdx]
|
||||
|
||||
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
|
||||
columnStart := dotIdx + 1
|
||||
columnEnd := len(columnRef)
|
||||
|
||||
for i := columnStart; i < len(columnRef); i++ {
|
||||
ch := columnRef[i]
|
||||
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
|
||||
columnEnd = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
column = columnRef[columnStart:columnEnd]
|
||||
|
||||
// Remove quotes from table and column if present
|
||||
table = strings.Trim(table, "`\"'")
|
||||
column = strings.Trim(column, "`\"'")
|
||||
|
||||
return table, column
|
||||
}
|
||||
}
|
||||
|
||||
// No function call - check if it contains a dot (qualified reference)
|
||||
// Use LastIndex to handle schema.table.column properly
|
||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||
table = columnRef[:dotIdx]
|
||||
column = columnRef[dotIdx+1:]
|
||||
@@ -437,6 +483,53 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||
depth := 0
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
ch := s[i]
|
||||
|
||||
// Track quote state (operators inside quotes should be ignored)
|
||||
if ch == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if ch == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if we're inside quotes
|
||||
if inSingleQuote || inDoubleQuote {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track parenthesis depth
|
||||
switch ch {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
}
|
||||
|
||||
// Only look for the operator when we're outside parentheses (depth == 0)
|
||||
if depth == 0 {
|
||||
// Check if the operator starts at this position
|
||||
if i+len(operator) <= len(s) {
|
||||
if s[i:i+len(operator)] == operator {
|
||||
return i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return -1
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
|
||||
@@ -122,6 +122,18 @@ func TestSanitizeWhereClause(t *testing.T) {
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "subquery with table alias should not be modified",
|
||||
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
|
||||
tableName: "apiprovider",
|
||||
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
|
||||
},
|
||||
{
|
||||
name: "complex subquery with AND and multiple operators",
|
||||
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
||||
tableName: "apiprovider",
|
||||
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -274,6 +286,48 @@ func TestExtractTableAndColumn(t *testing.T) {
|
||||
expectedTable: "",
|
||||
expectedCol: "",
|
||||
},
|
||||
{
|
||||
name: "function call with table.column - ifblnk",
|
||||
input: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "function call with table.column - coalesce",
|
||||
input: "coalesce(users.age, 0) = 25",
|
||||
expectedTable: "users",
|
||||
expectedCol: "age",
|
||||
},
|
||||
{
|
||||
name: "nested function calls",
|
||||
input: "upper(trim(users.name)) = 'JOHN'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "name",
|
||||
},
|
||||
{
|
||||
name: "function with multiple args and table.column",
|
||||
input: "substring(users.email, 1, 5) = 'admin'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "email",
|
||||
},
|
||||
{
|
||||
name: "cast function with table.column",
|
||||
input: "cast(orders.total as decimal) > 100",
|
||||
expectedTable: "orders",
|
||||
expectedCol: "total",
|
||||
},
|
||||
{
|
||||
name: "complex nested functions",
|
||||
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "status",
|
||||
},
|
||||
{
|
||||
name: "function with multiple table.column refs (extracts first)",
|
||||
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
|
||||
expectedTable: "users",
|
||||
expectedCol: "created_at",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -340,6 +394,14 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
||||
},
|
||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||
},
|
||||
|
||||
{
|
||||
name: "Function Call with correct table prefix - unchanged",
|
||||
where: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
tableName: "users",
|
||||
options: nil,
|
||||
expected: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||
},
|
||||
{
|
||||
name: "no options provided - works as before",
|
||||
where: "wrong_table.status = 'active'",
|
||||
|
||||
@@ -24,6 +24,13 @@ type Handler struct {
|
||||
hooks *HookRegistry
|
||||
}
|
||||
|
||||
type SqlQueryOptions struct {
|
||||
GetVariablesCallback func(w http.ResponseWriter, r *http.Request) map[string]interface{}
|
||||
NoCount bool
|
||||
BlankParams bool
|
||||
AllowFilter bool
|
||||
}
|
||||
|
||||
// NewHandler creates a new function API handler
|
||||
func NewHandler(db common.Database) *Handler {
|
||||
return &Handler{
|
||||
@@ -48,7 +55,7 @@ func (h *Handler) Hooks() *HookRegistry {
|
||||
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
||||
|
||||
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
||||
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
|
||||
func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -70,6 +77,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
if options.GetVariablesCallback != nil {
|
||||
variables = options.GetVariablesCallback(w, r)
|
||||
}
|
||||
complexAPI := false
|
||||
|
||||
// Get user context from security package
|
||||
@@ -93,9 +103,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
MetaInfo: metainfo,
|
||||
PropQry: propQry,
|
||||
UserContext: userCtx,
|
||||
NoCount: pNoCount,
|
||||
BlankParams: pBlankparms,
|
||||
AllowFilter: pAllowFilter,
|
||||
NoCount: options.NoCount,
|
||||
BlankParams: options.BlankParams,
|
||||
AllowFilter: options.AllowFilter,
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
@@ -131,13 +141,13 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
complexAPI = reqParams.ComplexAPI
|
||||
|
||||
// Merge query string parameters
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry)
|
||||
|
||||
// Merge header parameters
|
||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||
|
||||
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
||||
if !pAllowFilter {
|
||||
if !options.AllowFilter {
|
||||
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
||||
}
|
||||
|
||||
@@ -149,7 +159,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
|
||||
// Override pNoCount if skipcount is specified
|
||||
if reqParams.SkipCount {
|
||||
pNoCount = true
|
||||
options.NoCount = true
|
||||
}
|
||||
|
||||
// Build metainfo
|
||||
@@ -164,7 +174,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
if options.BlankParams {
|
||||
for _, kw := range inputvars {
|
||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||
@@ -205,7 +215,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
||||
}
|
||||
|
||||
if !pNoCount {
|
||||
if !options.NoCount {
|
||||
if limit > 0 && offset > 0 {
|
||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
||||
} else if limit > 0 {
|
||||
@@ -244,7 +254,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
// Normalize PostgreSQL types for proper JSON marshaling
|
||||
dbobjlist = normalizePostgresTypesList(rows)
|
||||
|
||||
if pNoCount {
|
||||
if options.NoCount {
|
||||
total = int64(len(dbobjlist))
|
||||
}
|
||||
|
||||
@@ -386,7 +396,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
||||
}
|
||||
|
||||
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
||||
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
@@ -406,6 +416,9 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
inputvars := make([]string, 0)
|
||||
metainfo := make(map[string]interface{})
|
||||
variables := make(map[string]interface{})
|
||||
if options.GetVariablesCallback != nil {
|
||||
variables = options.GetVariablesCallback(w, r)
|
||||
}
|
||||
dbobj := make(map[string]interface{})
|
||||
complexAPI := false
|
||||
|
||||
@@ -430,7 +443,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
MetaInfo: metainfo,
|
||||
PropQry: propQry,
|
||||
UserContext: userCtx,
|
||||
BlankParams: pBlankparms,
|
||||
BlankParams: options.BlankParams,
|
||||
ComplexAPI: complexAPI,
|
||||
}
|
||||
|
||||
@@ -507,7 +520,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
||||
}
|
||||
|
||||
// Remove unused input variables
|
||||
if pBlankparms {
|
||||
if options.BlankParams {
|
||||
for _, kw := range inputvars {
|
||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||
|
||||
@@ -532,7 +532,7 @@ func TestSqlQuery(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
|
||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
@@ -655,7 +655,7 @@ func TestSqlQueryList(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
|
||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
|
||||
@@ -576,7 +576,7 @@ func TestHookIntegrationWithHandler(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{})
|
||||
handlerFunc(w, req)
|
||||
|
||||
if !hookCalled {
|
||||
|
||||
@@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
|
||||
// Validate and filter columns in options (log warnings for invalid columns)
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = filterExtendedOptions(validator, options)
|
||||
options = h.filterExtendedOptions(validator, options, model)
|
||||
|
||||
// Add request-scoped data to context (including options)
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||
@@ -2241,7 +2241,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
||||
}
|
||||
|
||||
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
||||
func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions {
|
||||
filtered := options
|
||||
|
||||
// Filter base RequestOptions
|
||||
@@ -2265,12 +2265,30 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
|
||||
// No filtering needed for ComputedQL keys
|
||||
filtered.ComputedQL = options.ComputedQL
|
||||
|
||||
// Filter Expand columns
|
||||
// Filter Expand columns using the expand relation's model
|
||||
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
for _, expand := range options.Expand {
|
||||
filteredExpand := expand
|
||||
// Don't validate relation name, only columns
|
||||
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
|
||||
|
||||
// Get the relationship info for this expand relation
|
||||
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
|
||||
if relInfo != nil && relInfo.relatedModel != nil {
|
||||
// Create a validator for the related model
|
||||
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||
// Filter columns using the related model's validator
|
||||
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||
} else {
|
||||
// If we can't find the relationship, log a warning and skip column filtering
|
||||
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||
// Keep the columns as-is if we can't validate them
|
||||
filteredExpand.Columns = expand.Columns
|
||||
}
|
||||
|
||||
filteredExpands = append(filteredExpands, filteredExpand)
|
||||
}
|
||||
filtered.Expand = filteredExpands
|
||||
|
||||
@@ -29,10 +29,11 @@ type LoginRequest struct {
|
||||
|
||||
// LoginResponse contains the result of a login attempt
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
}
|
||||
|
||||
// LogoutRequest contains information for logout
|
||||
|
||||
@@ -111,7 +111,7 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, userJSON).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
|
||||
Reference in New Issue
Block a user