mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-14 01:20:36 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
932f12ab0a | ||
|
|
b22792bad6 | ||
|
|
e8111c01aa | ||
|
|
5862016031 | ||
|
|
2f18dde29c | ||
|
|
31ad217818 | ||
|
|
7ef1d6424a | ||
|
|
c50eeac5bf | ||
|
|
6d88f2668a | ||
|
|
8a9423df6d | ||
|
|
4cc943b9d3 | ||
|
|
68dee78a34 | ||
|
|
efb9e5d9d5 | ||
|
|
490ae37c6d | ||
|
|
99307e31e6 |
82
.github/workflows/make_tag.yml
vendored
Normal file
82
.github/workflows/make_tag.yml
vendored
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
# This workflow will build a golang project
|
||||||
|
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
|
||||||
|
|
||||||
|
name: Create Go Release (Tag Versioning)
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
semver:
|
||||||
|
description: "New Version"
|
||||||
|
required: true
|
||||||
|
default: "patch"
|
||||||
|
type: choice
|
||||||
|
options:
|
||||||
|
- patch
|
||||||
|
- minor
|
||||||
|
- major
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
tag_and_commit:
|
||||||
|
name: "Tag and Commit ${{ github.event.inputs.semver }}"
|
||||||
|
runs-on: linux
|
||||||
|
permissions:
|
||||||
|
contents: write # 'write' access to repository contents
|
||||||
|
pull-requests: write # 'write' access to pull requests
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v2
|
||||||
|
|
||||||
|
- name: Set up Git
|
||||||
|
run: |
|
||||||
|
git config --global user.name "Hein"
|
||||||
|
git config --global user.email "hein.puth@gmail.com"
|
||||||
|
|
||||||
|
- name: Fetch latest tag
|
||||||
|
id: latest_tag
|
||||||
|
run: |
|
||||||
|
git fetch --tags
|
||||||
|
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
|
||||||
|
echo "::set-output name=tag::$latest_tag"
|
||||||
|
|
||||||
|
- name: Determine new tag version
|
||||||
|
id: new_tag
|
||||||
|
run: |
|
||||||
|
current_tag=${{ steps.latest_tag.outputs.tag }}
|
||||||
|
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
|
||||||
|
IFS='.' read -r -a version_parts <<< "$version"
|
||||||
|
major=${version_parts[0]}
|
||||||
|
minor=${version_parts[1]}
|
||||||
|
patch=${version_parts[2]}
|
||||||
|
case "${{ github.event.inputs.semver }}" in
|
||||||
|
"patch")
|
||||||
|
((patch++))
|
||||||
|
;;
|
||||||
|
"minor")
|
||||||
|
((minor++))
|
||||||
|
patch=0
|
||||||
|
;;
|
||||||
|
"release")
|
||||||
|
((major++))
|
||||||
|
minor=0
|
||||||
|
patch=0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Invalid semver input"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
new_tag="v$major.$minor.$patch"
|
||||||
|
echo "::set-output name=tag::$new_tag"
|
||||||
|
|
||||||
|
- name: Create tag
|
||||||
|
run: |
|
||||||
|
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
|
||||||
|
|
||||||
|
- name: Push changes
|
||||||
|
uses: ad-m/github-push-action@master
|
||||||
|
with:
|
||||||
|
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
|
||||||
|
force: true
|
||||||
|
tags: true
|
||||||
14
.vscode/tasks.json
vendored
14
.vscode/tasks.json
vendored
@ -230,7 +230,17 @@
|
|||||||
"cwd": "${workspaceFolder}"
|
"cwd": "${workspaceFolder}"
|
||||||
},
|
},
|
||||||
"problemMatcher": [],
|
"problemMatcher": [],
|
||||||
"group": "test"
|
"group": "build"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "shell",
|
||||||
|
"label": "go: lint workspace (fix)",
|
||||||
|
"command": "golangci-lint run --timeout=5m --fix",
|
||||||
|
"options": {
|
||||||
|
"cwd": "${workspaceFolder}"
|
||||||
|
},
|
||||||
|
"problemMatcher": [],
|
||||||
|
"group": "build"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "shell",
|
"type": "shell",
|
||||||
@ -275,4 +285,4 @@
|
|||||||
"command": "sh ${workspaceFolder}/make_release.sh"
|
"command": "sh ${workspaceFolder}/make_release.sh"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@ -48,21 +48,42 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Log the type being scanned into
|
// Log the type being scanned into
|
||||||
logger.Debug("Debug scan into type: %s (kind: %s)", v.Type().Name(), v.Kind())
|
typeName := v.Type().String()
|
||||||
|
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
|
||||||
|
|
||||||
// If it's a struct, log all field types
|
// Handle slice types - inspect the element type
|
||||||
if v.Kind() == reflect.Struct {
|
var structType reflect.Type
|
||||||
for i := 0; i < v.NumField(); i++ {
|
if v.Kind() == reflect.Slice {
|
||||||
field := v.Type().Field(i)
|
elemType := v.Type().Elem()
|
||||||
fieldValue := v.Field(i)
|
logger.Debug(" Slice element type: %s", elemType)
|
||||||
|
|
||||||
|
// If slice of pointers, get the underlying type
|
||||||
|
if elemType.Kind() == reflect.Ptr {
|
||||||
|
structType = elemType.Elem()
|
||||||
|
} else {
|
||||||
|
structType = elemType
|
||||||
|
}
|
||||||
|
} else if v.Kind() == reflect.Struct {
|
||||||
|
structType = v.Type()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have a struct type, log all its fields
|
||||||
|
if structType != nil && structType.Kind() == reflect.Struct {
|
||||||
|
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
|
||||||
|
for i := 0; i < structType.NumField(); i++ {
|
||||||
|
field := structType.Field(i)
|
||||||
|
|
||||||
// Log embedded fields specially
|
// Log embedded fields specially
|
||||||
if field.Anonymous {
|
if field.Anonymous {
|
||||||
logger.Debug(" Embedded field [%d]: %s (type: %s, kind: %s)",
|
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
|
||||||
i, field.Name, field.Type, fieldValue.Kind())
|
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
|
||||||
} else {
|
} else {
|
||||||
logger.Debug(" Field [%d]: %s (type: %s, kind: %s, tag: %s)",
|
bunTag := field.Tag.Get("bun")
|
||||||
i, field.Name, field.Type, fieldValue.Kind(), field.Tag.Get("bun"))
|
if bunTag == "" {
|
||||||
|
bunTag = "(no tag)"
|
||||||
|
}
|
||||||
|
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
|
||||||
|
i, field.Name, field.Type, field.Type.Kind(), bunTag)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -175,6 +196,10 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return b.db
|
||||||
|
}
|
||||||
|
|
||||||
// BunSelectQuery implements SelectQuery for Bun
|
// BunSelectQuery implements SelectQuery for Bun
|
||||||
type BunSelectQuery struct {
|
type BunSelectQuery struct {
|
||||||
query *bun.SelectQuery
|
query *bun.SelectQuery
|
||||||
@ -1187,3 +1212,7 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error {
|
|||||||
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||||
return fn(b) // Already in transaction
|
return fn(b) // Already in transaction
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return b.tx
|
||||||
|
}
|
||||||
|
|||||||
@ -102,6 +102,10 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||||
|
return g.db
|
||||||
|
}
|
||||||
|
|
||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
|||||||
1363
pkg/common/adapters/database/pgsql.go
Normal file
1363
pkg/common/adapters/database/pgsql.go
Normal file
File diff suppressed because it is too large
Load Diff
176
pkg/common/adapters/database/pgsql_example.go
Normal file
176
pkg/common/adapters/database/pgsql_example.go
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example demonstrates how to use the PgSQL adapter
|
||||||
|
func ExamplePgSQLAdapter() error {
|
||||||
|
// Connect to PostgreSQL database
|
||||||
|
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||||
|
db, err := sql.Open("pgx", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
// Create the PgSQL adapter
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
|
||||||
|
// Enable query debugging (optional)
|
||||||
|
adapter.EnableQueryDebug()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Example 1: Simple SELECT query
|
||||||
|
var results []map[string]interface{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("age > ?", 18).
|
||||||
|
Order("created_at DESC").
|
||||||
|
Limit(10).
|
||||||
|
Scan(ctx, &results)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("select failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 2: INSERT query
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "John Doe").
|
||||||
|
Value("email", "john@example.com").
|
||||||
|
Value("age", 25).
|
||||||
|
Returning("id").
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("insert failed: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
|
||||||
|
|
||||||
|
// Example 3: UPDATE query
|
||||||
|
result, err = adapter.NewUpdate().
|
||||||
|
Table("users").
|
||||||
|
Set("name", "Jane Doe").
|
||||||
|
Where("id = ?", 1).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("update failed: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
|
||||||
|
|
||||||
|
// Example 4: DELETE query
|
||||||
|
result, err = adapter.NewDelete().
|
||||||
|
Table("users").
|
||||||
|
Where("age < ?", 18).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete failed: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
|
||||||
|
|
||||||
|
// Example 5: Using transactions
|
||||||
|
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Insert a new user
|
||||||
|
_, err := tx.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Transaction User").
|
||||||
|
Value("email", "tx@example.com").
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update another user
|
||||||
|
_, err = tx.NewUpdate().
|
||||||
|
Table("users").
|
||||||
|
Set("verified", true).
|
||||||
|
Where("email = ?", "tx@example.com").
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both operations succeed or both rollback
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("transaction failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 6: JOIN query
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Table("users u").
|
||||||
|
Column("u.id", "u.name", "p.title as post_title").
|
||||||
|
LeftJoin("posts p ON p.user_id = u.id").
|
||||||
|
Where("u.active = ?", true).
|
||||||
|
Scan(ctx, &results)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("join query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 7: Aggregation query
|
||||||
|
count, err := adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("active = ?", true).
|
||||||
|
Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("count failed: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("Active users: %d\n", count)
|
||||||
|
|
||||||
|
// Example 8: Raw SQL execution
|
||||||
|
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("raw exec failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example 9: Raw SQL query
|
||||||
|
var users []map[string]interface{}
|
||||||
|
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("raw query failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// User is an example model
|
||||||
|
type User struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Age int `json:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableName implements common.TableNameProvider
|
||||||
|
func (u User) TableName() string {
|
||||||
|
return "users"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleWithModel demonstrates using models with the PgSQL adapter
|
||||||
|
func ExampleWithModel() error {
|
||||||
|
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
|
||||||
|
db, err := sql.Open("pgx", dsn)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Use model with adapter
|
||||||
|
user := User{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Model(&user).
|
||||||
|
Where("id = ?", 1).
|
||||||
|
Scan(ctx, &user)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
526
pkg/common/adapters/database/pgsql_integration_test.go
Normal file
@ -0,0 +1,526 @@
|
|||||||
|
// +build integration
|
||||||
|
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/testcontainers/testcontainers-go"
|
||||||
|
"github.com/testcontainers/testcontainers-go/wait"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Integration test models
|
||||||
|
type IntegrationUser struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
Email string `db:"email"`
|
||||||
|
Age int `db:"age"`
|
||||||
|
CreatedAt time.Time `db:"created_at"`
|
||||||
|
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u IntegrationUser) TableName() string {
|
||||||
|
return "users"
|
||||||
|
}
|
||||||
|
|
||||||
|
type IntegrationPost struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Title string `db:"title"`
|
||||||
|
Content string `db:"content"`
|
||||||
|
UserID int `db:"user_id"`
|
||||||
|
Published bool `db:"published"`
|
||||||
|
CreatedAt time.Time `db:"created_at"`
|
||||||
|
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
|
||||||
|
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p IntegrationPost) TableName() string {
|
||||||
|
return "posts"
|
||||||
|
}
|
||||||
|
|
||||||
|
type IntegrationComment struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
Content string `db:"content"`
|
||||||
|
PostID int `db:"post_id"`
|
||||||
|
CreatedAt time.Time `db:"created_at"`
|
||||||
|
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c IntegrationComment) TableName() string {
|
||||||
|
return "comments"
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupTestDB creates a PostgreSQL container and returns the connection
|
||||||
|
func setupTestDB(t *testing.T) (*sql.DB, func()) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
req := testcontainers.ContainerRequest{
|
||||||
|
Image: "postgres:15-alpine",
|
||||||
|
ExposedPorts: []string{"5432/tcp"},
|
||||||
|
Env: map[string]string{
|
||||||
|
"POSTGRES_USER": "testuser",
|
||||||
|
"POSTGRES_PASSWORD": "testpass",
|
||||||
|
"POSTGRES_DB": "testdb",
|
||||||
|
},
|
||||||
|
WaitingFor: wait.ForLog("database system is ready to accept connections").
|
||||||
|
WithOccurrence(2).
|
||||||
|
WithStartupTimeout(60 * time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||||
|
ContainerRequest: req,
|
||||||
|
Started: true,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
host, err := postgres.Host(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
port, err := postgres.MappedPort(ctx, "5432")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
|
||||||
|
host, port.Port())
|
||||||
|
|
||||||
|
db, err := sql.Open("pgx", dsn)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Wait for database to be ready
|
||||||
|
err = db.Ping()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create schema
|
||||||
|
createSchema(t, db)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
db.Close()
|
||||||
|
postgres.Terminate(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSchema creates test tables
|
||||||
|
func createSchema(t *testing.T, db *sql.DB) {
|
||||||
|
schema := `
|
||||||
|
DROP TABLE IF EXISTS comments CASCADE;
|
||||||
|
DROP TABLE IF EXISTS posts CASCADE;
|
||||||
|
DROP TABLE IF EXISTS users CASCADE;
|
||||||
|
|
||||||
|
CREATE TABLE users (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
name VARCHAR(255) NOT NULL,
|
||||||
|
email VARCHAR(255) UNIQUE NOT NULL,
|
||||||
|
age INT NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE posts (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
title VARCHAR(255) NOT NULL,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
published BOOLEAN DEFAULT false,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE comments (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
_, err := db.Exec(schema)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_BasicCRUD tests basic CRUD operations
|
||||||
|
func TestIntegration_BasicCRUD(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// CREATE
|
||||||
|
result, err := adapter.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "John Doe").
|
||||||
|
Value("email", "john@example.com").
|
||||||
|
Value("age", 25).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
// READ
|
||||||
|
var users []IntegrationUser
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("email = ?", "john@example.com").
|
||||||
|
Scan(ctx, &users)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, users, 1)
|
||||||
|
assert.Equal(t, "John Doe", users[0].Name)
|
||||||
|
assert.Equal(t, 25, users[0].Age)
|
||||||
|
|
||||||
|
userID := users[0].ID
|
||||||
|
|
||||||
|
// UPDATE
|
||||||
|
result, err = adapter.NewUpdate().
|
||||||
|
Table("users").
|
||||||
|
Set("age", 26).
|
||||||
|
Where("id = ?", userID).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
// Verify update
|
||||||
|
var updatedUser IntegrationUser
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("id = ?", userID).
|
||||||
|
Scan(ctx, &updatedUser)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 26, updatedUser.Age)
|
||||||
|
|
||||||
|
// DELETE
|
||||||
|
result, err = adapter.NewDelete().
|
||||||
|
Table("users").
|
||||||
|
Where("id = ?", userID).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), result.RowsAffected())
|
||||||
|
|
||||||
|
// Verify delete
|
||||||
|
count, err := adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("id = ?", userID).
|
||||||
|
Count(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_ScanModel tests ScanModel functionality
|
||||||
|
func TestIntegration_ScanModel(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Insert test data
|
||||||
|
_, err := adapter.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Jane Smith").
|
||||||
|
Value("email", "jane@example.com").
|
||||||
|
Value("age", 30).
|
||||||
|
Exec(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test single struct scan
|
||||||
|
user := &IntegrationUser{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Model(user).
|
||||||
|
Table("users").
|
||||||
|
Where("email = ?", "jane@example.com").
|
||||||
|
ScanModel(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Jane Smith", user.Name)
|
||||||
|
assert.Equal(t, 30, user.Age)
|
||||||
|
|
||||||
|
// Test slice scan
|
||||||
|
users := []*IntegrationUser{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Model(&users).
|
||||||
|
Table("users").
|
||||||
|
ScanModel(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, users, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Transaction tests transaction handling
|
||||||
|
func TestIntegration_Transaction(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Successful transaction
|
||||||
|
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
_, err := tx.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Alice").
|
||||||
|
Value("email", "alice@example.com").
|
||||||
|
Value("age", 28).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Bob").
|
||||||
|
Value("email", "bob@example.com").
|
||||||
|
Value("age", 32).
|
||||||
|
Exec(ctx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify both records exist
|
||||||
|
count, err := adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
// Failed transaction (should rollback)
|
||||||
|
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
_, err := tx.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Charlie").
|
||||||
|
Value("email", "charlie@example.com").
|
||||||
|
Value("age", 35).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Intentional error - duplicate email
|
||||||
|
_, err = tx.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "David").
|
||||||
|
Value("email", "alice@example.com"). // Duplicate
|
||||||
|
Value("age", 40).
|
||||||
|
Exec(ctx)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// Verify rollback - count should still be 2
|
||||||
|
count, err = adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Count(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Preload tests basic preload functionality
|
||||||
|
func TestIntegration_Preload(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
|
||||||
|
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
|
||||||
|
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
|
||||||
|
|
||||||
|
// Test Preload
|
||||||
|
var users []*IntegrationUser
|
||||||
|
err := adapter.NewSelect().
|
||||||
|
Model(&IntegrationUser{}).
|
||||||
|
Table("users").
|
||||||
|
Preload("Posts").
|
||||||
|
Scan(ctx, &users)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, users, 1)
|
||||||
|
assert.NotNil(t, users[0].Posts)
|
||||||
|
assert.Len(t, users[0].Posts, 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_PreloadRelation tests smart PreloadRelation
|
||||||
|
func TestIntegration_PreloadRelation(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
|
||||||
|
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
|
||||||
|
createTestComment(t, adapter, ctx, postID, "Great post!")
|
||||||
|
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
|
||||||
|
|
||||||
|
// Test PreloadRelation with belongs-to (should use JOIN)
|
||||||
|
var posts []*IntegrationPost
|
||||||
|
err := adapter.NewSelect().
|
||||||
|
Model(&IntegrationPost{}).
|
||||||
|
Table("posts").
|
||||||
|
PreloadRelation("User").
|
||||||
|
Scan(ctx, &posts)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, posts, 1)
|
||||||
|
// Note: JOIN preloading needs proper column selection to work
|
||||||
|
// For now, we test that it doesn't error
|
||||||
|
|
||||||
|
// Test PreloadRelation with has-many (should use subquery)
|
||||||
|
posts = []*IntegrationPost{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Model(&IntegrationPost{}).
|
||||||
|
Table("posts").
|
||||||
|
PreloadRelation("Comments").
|
||||||
|
Scan(ctx, &posts)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, posts, 1)
|
||||||
|
if posts[0].Comments != nil {
|
||||||
|
assert.Len(t, posts[0].Comments, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_JoinRelation tests explicit JoinRelation
|
||||||
|
func TestIntegration_JoinRelation(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
|
||||||
|
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
|
||||||
|
|
||||||
|
// Test JoinRelation
|
||||||
|
var posts []*IntegrationPost
|
||||||
|
err := adapter.NewSelect().
|
||||||
|
Model(&IntegrationPost{}).
|
||||||
|
Table("posts").
|
||||||
|
JoinRelation("User").
|
||||||
|
Scan(ctx, &posts)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, posts, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_ComplexQuery tests complex queries
|
||||||
|
func TestIntegration_ComplexQuery(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
|
||||||
|
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
|
||||||
|
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
|
||||||
|
|
||||||
|
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
|
||||||
|
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
|
||||||
|
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
|
||||||
|
|
||||||
|
// Complex query with joins, where, order, limit
|
||||||
|
var results []map[string]interface{}
|
||||||
|
err := adapter.NewSelect().
|
||||||
|
Table("posts p").
|
||||||
|
Column("p.title", "u.name as author_name", "u.age as author_age").
|
||||||
|
LeftJoin("users u ON u.id = p.user_id").
|
||||||
|
Where("p.published = ?", true).
|
||||||
|
WhereOr("u.age > ?", 25).
|
||||||
|
Order("u.age DESC").
|
||||||
|
Limit(2).
|
||||||
|
Scan(ctx, &results)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.LessOrEqual(t, len(results), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Aggregation tests aggregation queries
|
||||||
|
func TestIntegration_Aggregation(t *testing.T) {
|
||||||
|
db, cleanup := setupTestDB(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create test data
|
||||||
|
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
|
||||||
|
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
|
||||||
|
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
|
||||||
|
|
||||||
|
// Test Count
|
||||||
|
count, err := adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("age >= ?", 25).
|
||||||
|
Count(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
// Test Exists
|
||||||
|
exists, err := adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("email = ?", "user1@example.com").
|
||||||
|
Exists(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists)
|
||||||
|
|
||||||
|
// Test Group By with aggregation
|
||||||
|
var results []map[string]interface{}
|
||||||
|
err = adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Column("age", "COUNT(*) as count").
|
||||||
|
Group("age").
|
||||||
|
Having("COUNT(*) > ?", 0).
|
||||||
|
Order("age ASC").
|
||||||
|
Scan(ctx, &results)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, results, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
|
||||||
|
var userID int
|
||||||
|
err := adapter.Query(ctx, &userID,
|
||||||
|
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
name, email, age)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return userID
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
|
||||||
|
var postID int
|
||||||
|
err := adapter.Query(ctx, &postID,
|
||||||
|
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
|
||||||
|
title, content, userID, published)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return postID
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
|
||||||
|
var commentID int
|
||||||
|
err := adapter.Query(ctx, &commentID,
|
||||||
|
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
|
||||||
|
content, postID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return commentID
|
||||||
|
}
|
||||||
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
275
pkg/common/adapters/database/pgsql_preload_example.go
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/jackc/pgx/v5/stdlib"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example models for demonstrating preload functionality
|
||||||
|
|
||||||
|
// Author model - has many Posts.
// Column mapping uses `db` tags; relationship metadata uses bun-style tags
// parsed by the adapter's preload machinery.
type Author struct {
	ID    int     `db:"id"`
	Name  string  `db:"name"`
	Email string  `db:"email"`
	// Posts is populated by Preload/PreloadRelation; join maps authors.id to posts.author_id.
	Posts []*Post `bun:"rel:has-many,join:id=author_id"`
}

// TableName returns the database table backing Author.
func (a Author) TableName() string {
	return "authors"
}

// Post model - belongs to Author, has many Comments.
type Post struct {
	ID       int    `db:"id"`
	Title    string `db:"title"`
	Content  string `db:"content"`
	AuthorID int    `db:"author_id"`
	// Author is the owning side; join maps posts.author_id to authors.id.
	Author *Author `bun:"rel:belongs-to,join:author_id=id"`
	// Comments is populated by preloading; join maps posts.id to comments.post_id.
	Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
}

// TableName returns the database table backing Post.
func (p Post) TableName() string {
	return "posts"
}

// Comment model - belongs to Post.
type Comment struct {
	ID      int    `db:"id"`
	Content string `db:"content"`
	PostID  int    `db:"post_id"`
	// Post is the owning side; join maps comments.post_id to posts.id.
	Post *Post `bun:"rel:belongs-to,join:post_id=id"`
}

// TableName returns the database table backing Comment.
func (c Comment) TableName() string {
	return "comments"
}
|
||||||
|
|
||||||
|
// ExamplePreload demonstrates the Preload functionality.
// It opens a connection, selects all authors, and asks the adapter to load
// each author's Posts (a has-many relation, loaded via a follow-up query).
// NOTE(review): despite the Example prefix this lives in a non-test file and
// returns an error, so godoc will not run it as a testable example.
func ExamplePreload() error {
	dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	// Example 1: Simple Preload (uses subquery for has-many)
	var authors []*Author
	err = adapter.NewSelect().
		Model(&Author{}).
		Table("authors").
		Preload("Posts"). // Load all posts for each author
		Scan(ctx, &authors)
	if err != nil {
		return err
	}

	// Now authors[i].Posts will be populated with their posts

	return nil
}
|
||||||
|
|
||||||
|
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection.
// PreloadRelation inspects the relation tag and picks a loading strategy:
// has-many relations use a follow-up subquery, belongs-to relations use a JOIN.
// The optional callback customizes the relation's query (filters, ordering, limits).
func ExamplePreloadRelation() error {
	dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	// Example 1: PreloadRelation auto-detects has-many (uses subquery).
	// The callback narrows the preloaded posts to published ones, newest first.
	var authors []*Author
	err = adapter.NewSelect().
		Model(&Author{}).
		Table("authors").
		PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
			return q.Where("published = ?", true).Order("created_at DESC")
		}).
		Where("active = ?", true).
		Scan(ctx, &authors)
	if err != nil {
		return err
	}

	// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
	var posts []*Post
	err = adapter.NewSelect().
		Model(&Post{}).
		Table("posts").
		PreloadRelation("Author"). // Will use JOIN because it's belongs-to
		Scan(ctx, &posts)
	if err != nil {
		return err
	}

	// Example 3: Nested preloads
	err = adapter.NewSelect().
		Model(&Author{}).
		Table("authors").
		PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
			// First load posts, then preload comments for each post
			return q.Limit(10)
		}).
		Scan(ctx, &authors)
	if err != nil {
		return err
	}

	// Manually load nested relationships (two-level preloading).
	// Each post's comments are fetched with an individual query (N+1 pattern,
	// acceptable for an illustrative example).
	for _, author := range authors {
		if author.Posts != nil {
			for _, post := range author.Posts {
				var comments []*Comment
				err := adapter.NewSelect().
					Table("comments").
					Where("post_id = ?", post.ID).
					Scan(ctx, &comments)
				// Best-effort: a failed comment load leaves post.Comments nil
				// rather than aborting. NOTE(review): deliberate swallow here —
				// confirm this is the intended example behavior.
				if err == nil {
					post.Comments = comments
				}
			}
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// ExampleJoinRelation demonstrates explicit JOIN loading.
// JoinRelation forces a JOIN strategy for a relation regardless of its kind,
// and raw LeftJoin shows the manual alternative with hand-picked columns.
func ExampleJoinRelation() error {
	dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	// Example 1: Force JOIN for belongs-to relationship.
	// The callback adds a condition on the joined relation.
	var posts []*Post
	err = adapter.NewSelect().
		Model(&Post{}).
		Table("posts").
		JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
			return q.Where("active = ?", true)
		}).
		Scan(ctx, &posts)
	if err != nil {
		return err
	}

	// Example 2: Multiple JOINs built by hand with table aliases.
	err = adapter.NewSelect().
		Model(&Post{}).
		Table("posts p").
		Column("p.*", "a.name as author_name", "a.email as author_email").
		LeftJoin("authors a ON a.id = p.author_id").
		Where("p.published = ?", true).
		Scan(ctx, &posts)

	return err
}
|
||||||
|
|
||||||
|
// ExampleScanModel demonstrates ScanModel with struct destinations.
// ScanModel scans into whatever destination was registered via Model(),
// so no destination argument is passed at scan time.
func ExampleScanModel() error {
	dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	// Example 1: Scan single struct
	author := Author{}
	err = adapter.NewSelect().
		Model(&author).
		Table("authors").
		Where("id = ?", 1).
		ScanModel(ctx) // ScanModel automatically uses the model set with Model()

	if err != nil {
		return err
	}

	// Example 2: Scan slice of structs
	authors := []*Author{}
	err = adapter.NewSelect().
		Model(&authors).
		Table("authors").
		Where("active = ?", true).
		Limit(10).
		ScanModel(ctx)

	return err
}
|
||||||
|
|
||||||
|
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading:
// insert an author, read it back with its latest posts preloaded, then update it.
// Query debugging is enabled so every generated SQL statement is logged.
func ExampleCompleteWorkflow() error {
	dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
	db, err := sql.Open("pgx", dsn)
	if err != nil {
		return err
	}
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	adapter.EnableQueryDebug() // Enable query logging
	ctx := context.Background()

	// Step 1: Create an author
	author := &Author{
		Name:  "John Doe",
		Email: "john@example.com",
	}

	result, err := adapter.NewInsert().
		Table("authors").
		Value("name", author.Name).
		Value("email", author.Email).
		Returning("id").
		Exec(ctx)
	if err != nil {
		return err
	}

	// NOTE(review): the RETURNING id value is discarded here; the next step
	// hard-codes id = 1. Fine for an example, but confirm if readers are
	// expected to wire the returned id through instead.
	_ = result

	// Step 2: Load author with all their posts (at most 5, newest first).
	var loadedAuthor Author
	err = adapter.NewSelect().
		Model(&loadedAuthor).
		Table("authors").
		PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
			return q.Order("created_at DESC").Limit(5)
		}).
		Where("id = ?", 1).
		ScanModel(ctx)
	if err != nil {
		return err
	}

	// Step 3: Update author name
	_, err = adapter.NewUpdate().
		Table("authors").
		Set("name", "Jane Doe").
		Where("id = ?", 1).
		Exec(ctx)

	return err
}
|
||||||
629
pkg/common/adapters/database/pgsql_test.go
Normal file
629
pkg/common/adapters/database/pgsql_test.go
Normal file
@ -0,0 +1,629 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test models
// These mirror the example models but live in the unit-test file so sqlmock
// rows can be scanned into them without a real database.

// TestUser maps the users table for scanning tests.
type TestUser struct {
	ID    int    `db:"id"`
	Name  string `db:"name"`
	Email string `db:"email"`
	Age   int    `db:"age"`
}

// TableName returns the table backing TestUser.
func (u TestUser) TableName() string {
	return "users"
}

// TestPost maps the posts table and carries both a belongs-to (User) and a
// has-many (Comments) relation for relation-metadata tests.
type TestPost struct {
	ID      int       `db:"id"`
	Title   string    `db:"title"`
	Content string    `db:"content"`
	UserID  int       `db:"user_id"`
	User    *TestUser `bun:"rel:belongs-to,join:user_id=id"`
	Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
}

// TableName returns the table backing TestPost.
func (p TestPost) TableName() string {
	return "posts"
}

// TestComment maps the comments table.
type TestComment struct {
	ID      int    `db:"id"`
	Content string `db:"content"`
	PostID  int    `db:"post_id"`
}

// TableName returns the table backing TestComment.
func (c TestComment) TableName() string {
	return "comments"
}
|
||||||
|
|
||||||
|
// TestNewPgSQLAdapter tests adapter creation: the constructor must return a
// non-nil adapter that wraps exactly the *sql.DB it was given.
func TestNewPgSQLAdapter(t *testing.T) {
	db, _, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	assert.NotNil(t, adapter)
	assert.Equal(t, db, adapter.db)
}
|
||||||
|
|
||||||
|
// TestPgSQLSelectQuery_BuildSQL tests SQL query building.
// Each case configures a query's internal fields directly and asserts the
// exact SQL text produced by buildSQL, keeping the generator's output stable.
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
	tests := []struct {
		name     string
		setup    func(*PgSQLSelectQuery) // mutates the query under test
		expected string                  // exact SQL buildSQL must emit
	}{
		{
			name: "simple select",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
			},
			expected: "SELECT * FROM users",
		},
		{
			name: "select with columns",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
				q.columns = []string{"id", "name", "email"}
			},
			expected: "SELECT id, name, email FROM users",
		},
		{
			name: "select with where",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
				q.whereClauses = []string{"age > $1"}
				q.args = []interface{}{18}
			},
			expected: "SELECT * FROM users WHERE (age > $1)",
		},
		{
			name: "select with order and limit",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
				q.orderBy = []string{"created_at DESC"}
				q.limit = 10
				q.offset = 5
			},
			expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
		},
		{
			name: "select with join",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
				q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
			},
			expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
		},
		{
			name: "select with group and having",
			setup: func(q *PgSQLSelectQuery) {
				q.tableName = "users"
				q.groupBy = []string{"country"}
				q.havingClauses = []string{"COUNT(*) > $1"}
				q.args = []interface{}{5}
			},
			expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start from the default column set ("*") as the adapter does.
			q := &PgSQLSelectQuery{
				columns: []string{"*"},
			}
			tt.setup(q)
			sql := q.buildSQL()
			assert.Equal(t, tt.expected, sql)
		})
	}
}
|
||||||
|
|
||||||
|
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement:
// each '?' must be rewritten to a numbered PostgreSQL placeholder ($N),
// continuing from the query's current paramCounter.
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
	tests := []struct {
		name         string
		query        string // fragment containing '?' placeholders
		argCount     int    // number of args bound to the fragment
		paramCounter int    // placeholders already consumed before this fragment
		expected     string
	}{
		{
			name:         "single placeholder",
			query:        "age > ?",
			argCount:     1,
			paramCounter: 0,
			expected:     "age > $1",
		},
		{
			name:         "multiple placeholders",
			query:        "age > ? AND status = ?",
			argCount:     2,
			paramCounter: 0,
			expected:     "age > $1 AND status = $2",
		},
		{
			// Numbering must resume after previously used placeholders.
			name:         "with existing counter",
			query:        "name = ?",
			argCount:     1,
			paramCounter: 5,
			expected:     "name = $6",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
			result := q.replacePlaceholders(tt.query, tt.argCount)
			assert.Equal(t, tt.expected, result)
		})
	}
}
|
||||||
|
|
||||||
|
// TestPgSQLSelectQuery_Chaining tests method chaining: each builder method
// must record its value in the concrete query's internal state.
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
	db, _, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	query := adapter.NewSelect().
		Table("users").
		Column("id", "name").
		Where("age > ?", 18).
		Order("name ASC").
		Limit(10).
		Offset(5)

	// Downcast to the concrete type to inspect the accumulated state.
	pgQuery := query.(*PgSQLSelectQuery)
	assert.Equal(t, "users", pgQuery.tableName)
	assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
	assert.Len(t, pgQuery.whereClauses, 1)
	assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
	assert.Equal(t, 10, pgQuery.limit)
	assert.Equal(t, 5, pgQuery.offset)
}
|
||||||
|
|
||||||
|
// TestPgSQLSelectQuery_Model tests model setting: Model() must store the model
// and derive the table name from the model's TableName method.
func TestPgSQLSelectQuery_Model(t *testing.T) {
	db, _, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	adapter := NewPgSQLAdapter(db)
	user := &TestUser{}
	query := adapter.NewSelect().Model(user)

	pgQuery := query.(*PgSQLSelectQuery)
	assert.Equal(t, "users", pgQuery.tableName)
	assert.Equal(t, user, pgQuery.model)
}
|
||||||
|
|
||||||
|
// TestScanRowsToStructSlice tests scanning rows into struct slice:
// two mocked rows must produce two populated TestUser values.
func TestScanRowsToStructSlice(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
		AddRow(1, "John Doe", "john@example.com", 25).
		AddRow(2, "Jane Smith", "jane@example.com", 30)

	mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	var users []TestUser
	err = adapter.NewSelect().
		Table("users").
		Scan(ctx, &users)

	require.NoError(t, err)
	assert.Len(t, users, 2)
	assert.Equal(t, "John Doe", users[0].Name)
	assert.Equal(t, "jane@example.com", users[1].Email)
	assert.Equal(t, 30, users[1].Age)

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice:
// the scanner must allocate each element (*TestUser) itself.
func TestScanRowsToStructSlicePointers(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
		AddRow(1, "John Doe", "john@example.com", 25)

	mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	var users []*TestUser
	err = adapter.NewSelect().
		Table("users").
		Scan(ctx, &users)

	require.NoError(t, err)
	assert.Len(t, users, 1)
	assert.NotNil(t, users[0])
	assert.Equal(t, "John Doe", users[0].Name)

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestScanRowsToSingleStruct tests scanning a single row into a non-slice
// struct destination.
func TestScanRowsToSingleStruct(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
		AddRow(1, "John Doe", "john@example.com", 25)

	mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	var user TestUser
	err = adapter.NewSelect().
		Table("users").
		Where("id = ?", 1).
		Scan(ctx, &user)

	require.NoError(t, err)
	assert.Equal(t, 1, user.ID)
	assert.Equal(t, "John Doe", user.Name)

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestScanRowsToMapSlice tests scanning into map slice. Note the integer
// column comes back as int64 — the generic map scan preserves driver types.
func TestScanRowsToMapSlice(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"id", "name", "email"}).
		AddRow(1, "John Doe", "john@example.com").
		AddRow(2, "Jane Smith", "jane@example.com")

	mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	var results []map[string]interface{}
	err = adapter.NewSelect().
		Table("users").
		Scan(ctx, &results)

	require.NoError(t, err)
	assert.Len(t, results, 2)
	assert.Equal(t, int64(1), results[0]["id"])
	assert.Equal(t, "John Doe", results[0]["name"])

	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
|
|
||||||
|
// TestPgSQLInsertQuery_Exec tests insert query execution: the builder must
// emit an INSERT whose bound args match the Value() calls in order.
func TestPgSQLInsertQuery_Exec(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	mock.ExpectExec("INSERT INTO users").
		WithArgs("John Doe", "john@example.com", 25).
		WillReturnResult(sqlmock.NewResult(1, 1))

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	result, err := adapter.NewInsert().
		Table("users").
		Value("name", "John Doe").
		Value("email", "john@example.com").
		Value("age", 25).
		Exec(ctx)

	require.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, int64(1), result.RowsAffected())

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestPgSQLUpdateQuery_Exec tests update query execution.
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	// Note: Args order is SET values first, then WHERE values
	mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
		WithArgs("Jane Doe", 1).
		WillReturnResult(sqlmock.NewResult(0, 1))

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	result, err := adapter.NewUpdate().
		Table("users").
		Set("name", "Jane Doe").
		Where("id = ?", 1).
		Exec(ctx)

	require.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, int64(1), result.RowsAffected())

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestPgSQLDeleteQuery_Exec tests delete query execution with a single
// numbered WHERE placeholder.
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
		WithArgs(1).
		WillReturnResult(sqlmock.NewResult(0, 1))

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	result, err := adapter.NewDelete().
		Table("users").
		Where("id = ?", 1).
		Exec(ctx)

	require.NoError(t, err)
	assert.NotNil(t, result)
	assert.Equal(t, int64(1), result.RowsAffected())

	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
|
|
||||||
|
// TestPgSQLSelectQuery_Count tests count query: Count must issue a
// SELECT COUNT(*) and return the scanned integer.
func TestPgSQLSelectQuery_Count(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	count, err := adapter.NewSelect().
		Table("users").
		Count(ctx)

	require.NoError(t, err)
	assert.Equal(t, 42, count)

	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestPgSQLSelectQuery_Exists tests exists query: Exists is implemented on
// top of COUNT(*) and must report true for a non-zero count.
func TestPgSQLSelectQuery_Exists(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	exists, err := adapter.NewSelect().
		Table("users").
		Where("email = ?", "john@example.com").
		Exists(ctx)

	require.NoError(t, err)
	assert.True(t, exists)

	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
|
|
||||||
|
// TestPgSQLAdapter_Transaction tests transaction handling: a callback that
// succeeds must produce BEGIN, the statement, then COMMIT.
func TestPgSQLAdapter_Transaction(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
		_, err := tx.NewInsert().
			Table("users").
			Value("name", "John").
			Exec(ctx)
		return err
	})

	require.NoError(t, err)
	assert.NoError(t, mock.ExpectationsWereMet())
}

// TestPgSQLAdapter_TransactionRollback tests transaction rollback: a failing
// callback must cause ROLLBACK and the error must propagate to the caller.
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	mock.ExpectBegin()
	mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
	mock.ExpectRollback()

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
		_, err := tx.NewInsert().
			Table("users").
			Value("name", "John").
			Exec(ctx)
		return err
	})

	assert.Error(t, err)
	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
|
|
||||||
|
// TestBuildFieldMap tests field mapping construction: every db-tagged field
// of TestUser must appear in the map keyed by its column name.
func TestBuildFieldMap(t *testing.T) {
	userType := reflect.TypeOf(TestUser{})
	fieldMap := buildFieldMap(userType, nil)

	assert.NotEmpty(t, fieldMap)

	// Check that fields are mapped
	assert.Contains(t, fieldMap, "id")
	assert.Contains(t, fieldMap, "name")
	assert.Contains(t, fieldMap, "email")
	assert.Contains(t, fieldMap, "age")

	// Check field info: the entry must carry the Go struct field name.
	idInfo := fieldMap["id"]
	assert.Equal(t, "ID", idInfo.Name)
}

// TestGetRelationMetadata tests relationship metadata extraction for both a
// belongs-to (User) and a has-many (Comments) relation on TestPost.
func TestGetRelationMetadata(t *testing.T) {
	q := &PgSQLSelectQuery{
		model: &TestPost{},
	}

	// Test belongs-to relationship
	meta := q.getRelationMetadata("User")
	assert.NotNil(t, meta)
	assert.Equal(t, "User", meta.fieldName)

	// Test has-many relationship
	meta = q.getRelationMetadata("Comments")
	assert.NotNil(t, meta)
	assert.Equal(t, "Comments", meta.fieldName)
}
|
||||||
|
|
||||||
|
// TestPreloadConfiguration tests preload configuration: Preload and
// PreloadRelation must register the relation without forcing a JOIN, while
// JoinRelation must set the useJoin flag.
func TestPreloadConfiguration(t *testing.T) {
	db, _, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	adapter := NewPgSQLAdapter(db)

	// Test Preload
	query := adapter.NewSelect().
		Model(&TestPost{}).
		Table("posts").
		Preload("User")

	pgQuery := query.(*PgSQLSelectQuery)
	assert.Len(t, pgQuery.preloads, 1)
	assert.Equal(t, "User", pgQuery.preloads[0].relation)
	assert.False(t, pgQuery.preloads[0].useJoin)

	// Test PreloadRelation
	query = adapter.NewSelect().
		Model(&TestPost{}).
		Table("posts").
		PreloadRelation("Comments")

	pgQuery = query.(*PgSQLSelectQuery)
	assert.Len(t, pgQuery.preloads, 1)
	assert.Equal(t, "Comments", pgQuery.preloads[0].relation)

	// Test JoinRelation
	query = adapter.NewSelect().
		Model(&TestPost{}).
		Table("posts").
		JoinRelation("User")

	pgQuery = query.(*PgSQLSelectQuery)
	assert.Len(t, pgQuery.preloads, 1)
	assert.Equal(t, "User", pgQuery.preloads[0].relation)
	assert.True(t, pgQuery.preloads[0].useJoin)
}
|
||||||
|
|
||||||
|
// TestScanModel tests ScanModel functionality: the row must be scanned into
// the destination registered via Model(), with no explicit scan target.
func TestScanModel(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
		AddRow(1, "John Doe", "john@example.com", 25)

	mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	user := &TestUser{}
	err = adapter.NewSelect().
		Model(user).
		Table("users").
		Where("id = ?", 1).
		ScanModel(ctx)

	require.NoError(t, err)
	assert.Equal(t, 1, user.ID)
	assert.Equal(t, "John Doe", user.Name)

	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
|
|
||||||
|
// TestRawSQL tests raw SQL execution through the adapter's Exec and Query
// escape hatches (statements written by hand, not via builders).
func TestRawSQL(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()

	// Test Exec
	mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))

	adapter := NewPgSQLAdapter(db)
	ctx := context.Background()

	_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
	require.NoError(t, err)

	// Test Query: raw SELECT scanned into a generic map slice.
	rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
	mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)

	var results []map[string]interface{}
	err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
	require.NoError(t, err)
	assert.Len(t, results, 1)

	assert.NoError(t, mock.ExpectationsWereMet())
}
|
||||||
132
pkg/common/adapters/database/test_helpers.go
Normal file
132
pkg/common/adapters/database/test_helpers.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHelper provides utilities for database testing
|
||||||
|
type TestHelper struct {
|
||||||
|
DB *sql.DB
|
||||||
|
Adapter *PgSQLAdapter
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestHelper creates a new test helper
|
||||||
|
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
|
||||||
|
return &TestHelper{
|
||||||
|
DB: db,
|
||||||
|
Adapter: NewPgSQLAdapter(db),
|
||||||
|
t: t,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupTables truncates all test tables
|
||||||
|
func (h *TestHelper) CleanupTables() {
|
||||||
|
ctx := context.Background()
|
||||||
|
tables := []string{"comments", "posts", "users"}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertUser inserts a test user and returns the ID
|
||||||
|
func (h *TestHelper) InsertUser(name, email string, age int) int {
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := h.Adapter.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", name).
|
||||||
|
Value("email", email).
|
||||||
|
Value("age", age).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
id, _ := result.LastInsertId()
|
||||||
|
return int(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertPost inserts a test post and returns the ID
|
||||||
|
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := h.Adapter.NewInsert().
|
||||||
|
Table("posts").
|
||||||
|
Value("user_id", userID).
|
||||||
|
Value("title", title).
|
||||||
|
Value("content", content).
|
||||||
|
Value("published", published).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
id, _ := result.LastInsertId()
|
||||||
|
return int(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertComment inserts a test comment and returns the ID
|
||||||
|
func (h *TestHelper) InsertComment(postID int, content string) int {
|
||||||
|
ctx := context.Background()
|
||||||
|
result, err := h.Adapter.NewInsert().
|
||||||
|
Table("comments").
|
||||||
|
Value("post_id", postID).
|
||||||
|
Value("content", content).
|
||||||
|
Exec(ctx)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
id, _ := result.LastInsertId()
|
||||||
|
return int(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertUserExists checks if a user exists by email
|
||||||
|
func (h *TestHelper) AssertUserExists(email string) {
|
||||||
|
ctx := context.Background()
|
||||||
|
exists, err := h.Adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("email = ?", email).
|
||||||
|
Exists(ctx)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
require.True(h.t, exists, "User with email %s should exist", email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AssertUserCount asserts the number of users
|
||||||
|
func (h *TestHelper) AssertUserCount(expected int) {
|
||||||
|
ctx := context.Background()
|
||||||
|
count, err := h.Adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Count(ctx)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
require.Equal(h.t, expected, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByEmail retrieves a user by email
|
||||||
|
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
|
||||||
|
ctx := context.Background()
|
||||||
|
var results []map[string]interface{}
|
||||||
|
err := h.Adapter.NewSelect().
|
||||||
|
Table("users").
|
||||||
|
Where("email = ?", email).
|
||||||
|
Scan(ctx, &results)
|
||||||
|
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
|
||||||
|
return results[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginTestTransaction starts a transaction for testing
|
||||||
|
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
|
||||||
|
ctx := context.Background()
|
||||||
|
tx, err := h.DB.BeginTx(ctx, nil)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
adapter := &PgSQLTxAdapter{tx: tx}
|
||||||
|
cleanup := func() {
|
||||||
|
tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return adapter, cleanup
|
||||||
|
}
|
||||||
@ -24,6 +24,12 @@ type Database interface {
|
|||||||
CommitTx(ctx context.Context) error
|
CommitTx(ctx context.Context) error
|
||||||
RollbackTx(ctx context.Context) error
|
RollbackTx(ctx context.Context) error
|
||||||
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
RunInTransaction(ctx context.Context, fn func(Database) error) error
|
||||||
|
|
||||||
|
// GetUnderlyingDB returns the underlying database connection
|
||||||
|
// For GORM, this returns *gorm.DB
|
||||||
|
// For Bun, this returns *bun.DB
|
||||||
|
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||||
|
GetUnderlyingDB() interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@ -207,6 +208,20 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||||
|
// If tableName is provided and the condition DOESN'T have a table prefix,
|
||||||
|
// qualify unambiguous column references to prevent "ambiguous column" errors
|
||||||
|
// when there are multiple joins on the same table (e.g., recursive preloads)
|
||||||
|
columnName := extractUnqualifiedColumnName(condToCheck)
|
||||||
|
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
|
||||||
|
// Qualify the column with the table name
|
||||||
|
// Be careful to only replace the column name, not other occurrences of the string
|
||||||
|
oldRef := columnName
|
||||||
|
newRef := tableName + "." + columnName
|
||||||
|
// Use word boundary matching to avoid replacing partial matches
|
||||||
|
cond = qualifyColumnInCondition(cond, oldRef, newRef)
|
||||||
|
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
validConditions = append(validConditions, cond)
|
validConditions = append(validConditions, cond)
|
||||||
@ -430,7 +445,45 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
// Remove any quotes
|
// Remove any quotes
|
||||||
columnRef = strings.Trim(columnRef, "`\"'")
|
columnRef = strings.Trim(columnRef, "`\"'")
|
||||||
|
|
||||||
// Check if it contains a dot (qualified reference)
|
// Check if there's a function call (contains opening parenthesis)
|
||||||
|
openParenIdx := strings.Index(columnRef, "(")
|
||||||
|
|
||||||
|
if openParenIdx >= 0 {
|
||||||
|
// There's a function call - find the FIRST dot after the opening paren
|
||||||
|
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
|
||||||
|
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
|
||||||
|
if dotIdx > 0 {
|
||||||
|
dotIdx += openParenIdx // Adjust to absolute position
|
||||||
|
|
||||||
|
// Extract table name (between paren and dot)
|
||||||
|
// Find the last opening paren before this dot
|
||||||
|
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
|
||||||
|
table = columnRef[lastOpenParen+1 : dotIdx]
|
||||||
|
|
||||||
|
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
|
||||||
|
columnStart := dotIdx + 1
|
||||||
|
columnEnd := len(columnRef)
|
||||||
|
|
||||||
|
for i := columnStart; i < len(columnRef); i++ {
|
||||||
|
ch := columnRef[i]
|
||||||
|
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
|
||||||
|
columnEnd = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
column = columnRef[columnStart:columnEnd]
|
||||||
|
|
||||||
|
// Remove quotes from table and column if present
|
||||||
|
table = strings.Trim(table, "`\"'")
|
||||||
|
column = strings.Trim(column, "`\"'")
|
||||||
|
|
||||||
|
return table, column
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No function call - check if it contains a dot (qualified reference)
|
||||||
|
// Use LastIndex to handle schema.table.column properly
|
||||||
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
|
||||||
table = columnRef[:dotIdx]
|
table = columnRef[:dotIdx]
|
||||||
column = columnRef[dotIdx+1:]
|
column = columnRef[dotIdx+1:]
|
||||||
@ -445,6 +498,86 @@ func extractTableAndColumn(cond string) (table string, column string) {
|
|||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractUnqualifiedColumnName returns the bare column name referenced on the
// left-hand side of a condition, or "" when the reference is already
// table-qualified, is a function call, or cannot be identified.
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem";
// "status = 'active'" returns "status".
func extractUnqualifiedColumnName(cond string) string {
	// Operators that may separate the column reference from the rest of the
	// condition; the earliest occurrence wins.
	operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}

	firstOp := -1
	for _, op := range operators {
		if idx := strings.Index(cond, op); idx > 0 && (firstOp == -1 || idx < firstOp) {
			firstOp = idx
		}
	}

	var columnRef string
	if firstOp > 0 {
		columnRef = strings.TrimSpace(cond[:firstOp])
	} else if fields := strings.Fields(cond); len(fields) > 0 {
		// No operator found: treat the first whitespace-separated token as the
		// candidate column reference.
		columnRef = fields[0]
	}
	if columnRef == "" {
		return ""
	}

	// Strip surrounding quoting.
	columnRef = strings.Trim(columnRef, "`\"'")

	// Qualified references ("tbl.col") and function calls are handled by
	// other extraction paths; report "not an unqualified column" here.
	if strings.ContainsAny(columnRef, ".(") {
		return ""
	}
	return columnRef
}
|
||||||
|
|
||||||
|
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
|
||||||
|
// Uses word boundaries to avoid partial matches
|
||||||
|
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
|
||||||
|
// returns "table.rid_item is null"
|
||||||
|
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
|
||||||
|
// Use word boundary matching with Go's supported regex syntax
|
||||||
|
// \b matches word boundaries
|
||||||
|
escapedOld := regexp.QuoteMeta(oldRef)
|
||||||
|
pattern := `\b` + escapedOld + `\b`
|
||||||
|
|
||||||
|
re, err := regexp.Compile(pattern)
|
||||||
|
if err != nil {
|
||||||
|
// If regex fails, fall back to simple string replacement
|
||||||
|
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
|
||||||
|
return strings.Replace(cond, oldRef, newRef, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
|
||||||
|
result := cond
|
||||||
|
matches := re.FindAllStringIndex(cond, -1)
|
||||||
|
|
||||||
|
// Process matches in reverse order to maintain correct indices
|
||||||
|
for i := len(matches) - 1; i >= 0; i-- {
|
||||||
|
match := matches[i]
|
||||||
|
start := match[0]
|
||||||
|
|
||||||
|
// Check if preceded by a dot (already qualified)
|
||||||
|
if start > 0 && cond[start-1] == '.' {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace this occurrence
|
||||||
|
result = result[:start] + newRef + result[match[1]:]
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
|
||||||
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
// Returns the index of the operator, or -1 if not found or only found inside parentheses
|
||||||
func findOperatorOutsideParentheses(s string, operator string) int {
|
func findOperatorOutsideParentheses(s string, operator string) int {
|
||||||
|
|||||||
@ -33,16 +33,16 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid condition with parentheses - no prefix added",
|
name: "valid condition with parentheses - prefix added to prevent ambiguity",
|
||||||
where: "(status = 'active')",
|
where: "(status = 'active')",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "mixed trivial and valid conditions - no prefix added",
|
name: "mixed trivial and valid conditions - prefix added",
|
||||||
where: "true AND status = 'active' AND 1=1",
|
where: "true AND status = 'active' AND 1=1",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active'",
|
expected: "users.status = 'active'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "condition with correct table prefix - unchanged",
|
name: "condition with correct table prefix - unchanged",
|
||||||
@ -63,10 +63,10 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
expected: "users.status = 'active' AND users.age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple valid conditions without prefix - no prefix added",
|
name: "multiple valid conditions without prefix - prefixes added",
|
||||||
where: "status = 'active' AND age > 18",
|
where: "status = 'active' AND age > 18",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18",
|
expected: "users.status = 'active' AND users.age > 18",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "no table name provided",
|
name: "no table name provided",
|
||||||
@ -90,13 +90,13 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
name: "mixed case AND operators",
|
name: "mixed case AND operators",
|
||||||
where: "status = 'active' AND age > 18 and name = 'John'",
|
where: "status = 'active' AND age > 18 and name = 'John'",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "status = 'active' AND age > 18 AND name = 'John'",
|
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "subquery with ORDER BY and LIMIT - allowed",
|
name: "subquery with ORDER BY and LIMIT - allowed",
|
||||||
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
tableName: "users",
|
tableName: "users",
|
||||||
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "dangerous DELETE keyword - blocked",
|
name: "dangerous DELETE keyword - blocked",
|
||||||
@ -286,6 +286,48 @@ func TestExtractTableAndColumn(t *testing.T) {
|
|||||||
expectedTable: "",
|
expectedTable: "",
|
||||||
expectedCol: "",
|
expectedCol: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "function call with table.column - ifblnk",
|
||||||
|
input: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function call with table.column - coalesce",
|
||||||
|
input: "coalesce(users.age, 0) = 25",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "age",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested function calls",
|
||||||
|
input: "upper(trim(users.name)) = 'JOHN'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "name",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function with multiple args and table.column",
|
||||||
|
input: "substring(users.email, 1, 5) = 'admin'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "email",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cast function with table.column",
|
||||||
|
input: "cast(orders.total as decimal) > 100",
|
||||||
|
expectedTable: "orders",
|
||||||
|
expectedCol: "total",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex nested functions",
|
||||||
|
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "status",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "function with multiple table.column refs (extracts first)",
|
||||||
|
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
|
||||||
|
expectedTable: "users",
|
||||||
|
expectedCol: "created_at",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@ -352,6 +394,14 @@ func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
expected: "users.status = 'active' AND Department.name = 'Engineering'",
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "Function Call with correct table prefix - unchanged",
|
||||||
|
where: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
tableName: "users",
|
||||||
|
options: nil,
|
||||||
|
expected: "ifblnk(users.status,0) in (1,2,3,4)",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "no options provided - works as before",
|
name: "no options provided - works as before",
|
||||||
where: "wrong_table.status = 'active'",
|
where: "wrong_table.status = 'active'",
|
||||||
|
|||||||
@ -12,6 +12,7 @@ type Config struct {
|
|||||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||||
CORS CORSConfig `mapstructure:"cors"`
|
CORS CORSConfig `mapstructure:"cors"`
|
||||||
Database DatabaseConfig `mapstructure:"database"`
|
Database DatabaseConfig `mapstructure:"database"`
|
||||||
|
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig holds server-related configuration
|
// ServerConfig holds server-related configuration
|
||||||
@ -91,3 +92,52 @@ type ErrorTrackingConfig struct {
|
|||||||
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
|
||||||
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EventBrokerConfig contains configuration for the event broker
|
||||||
|
type EventBrokerConfig struct {
|
||||||
|
Enabled bool `mapstructure:"enabled"`
|
||||||
|
Provider string `mapstructure:"provider"` // memory, redis, nats, database
|
||||||
|
Mode string `mapstructure:"mode"` // sync, async
|
||||||
|
WorkerCount int `mapstructure:"worker_count"`
|
||||||
|
BufferSize int `mapstructure:"buffer_size"`
|
||||||
|
InstanceID string `mapstructure:"instance_id"`
|
||||||
|
Redis EventBrokerRedisConfig `mapstructure:"redis"`
|
||||||
|
NATS EventBrokerNATSConfig `mapstructure:"nats"`
|
||||||
|
Database EventBrokerDatabaseConfig `mapstructure:"database"`
|
||||||
|
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerRedisConfig contains Redis-specific configuration
|
||||||
|
type EventBrokerRedisConfig struct {
|
||||||
|
StreamName string `mapstructure:"stream_name"`
|
||||||
|
ConsumerGroup string `mapstructure:"consumer_group"`
|
||||||
|
MaxLen int64 `mapstructure:"max_len"`
|
||||||
|
Host string `mapstructure:"host"`
|
||||||
|
Port int `mapstructure:"port"`
|
||||||
|
Password string `mapstructure:"password"`
|
||||||
|
DB int `mapstructure:"db"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerNATSConfig contains NATS-specific configuration
|
||||||
|
type EventBrokerNATSConfig struct {
|
||||||
|
URL string `mapstructure:"url"`
|
||||||
|
StreamName string `mapstructure:"stream_name"`
|
||||||
|
Subjects []string `mapstructure:"subjects"`
|
||||||
|
Storage string `mapstructure:"storage"` // file, memory
|
||||||
|
MaxAge time.Duration `mapstructure:"max_age"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerDatabaseConfig contains database provider configuration
|
||||||
|
type EventBrokerDatabaseConfig struct {
|
||||||
|
TableName string `mapstructure:"table_name"`
|
||||||
|
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
|
||||||
|
PollInterval time.Duration `mapstructure:"poll_interval"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventBrokerRetryPolicyConfig contains retry policy configuration
|
||||||
|
type EventBrokerRetryPolicyConfig struct {
|
||||||
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
|
InitialDelay time.Duration `mapstructure:"initial_delay"`
|
||||||
|
MaxDelay time.Duration `mapstructure:"max_delay"`
|
||||||
|
BackoffFactor float64 `mapstructure:"backoff_factor"`
|
||||||
|
}
|
||||||
|
|||||||
@ -165,4 +165,39 @@ func setDefaults(v *viper.Viper) {
|
|||||||
|
|
||||||
// Database defaults
|
// Database defaults
|
||||||
v.SetDefault("database.url", "")
|
v.SetDefault("database.url", "")
|
||||||
|
|
||||||
|
// Event Broker defaults
|
||||||
|
v.SetDefault("event_broker.enabled", false)
|
||||||
|
v.SetDefault("event_broker.provider", "memory")
|
||||||
|
v.SetDefault("event_broker.mode", "async")
|
||||||
|
v.SetDefault("event_broker.worker_count", 10)
|
||||||
|
v.SetDefault("event_broker.buffer_size", 1000)
|
||||||
|
v.SetDefault("event_broker.instance_id", "")
|
||||||
|
|
||||||
|
// Event Broker - Redis defaults
|
||||||
|
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
|
||||||
|
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
|
||||||
|
v.SetDefault("event_broker.redis.max_len", 10000)
|
||||||
|
v.SetDefault("event_broker.redis.host", "localhost")
|
||||||
|
v.SetDefault("event_broker.redis.port", 6379)
|
||||||
|
v.SetDefault("event_broker.redis.password", "")
|
||||||
|
v.SetDefault("event_broker.redis.db", 0)
|
||||||
|
|
||||||
|
// Event Broker - NATS defaults
|
||||||
|
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
|
||||||
|
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
|
||||||
|
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
|
||||||
|
v.SetDefault("event_broker.nats.storage", "file")
|
||||||
|
v.SetDefault("event_broker.nats.max_age", "24h")
|
||||||
|
|
||||||
|
// Event Broker - Database defaults
|
||||||
|
v.SetDefault("event_broker.database.table_name", "events")
|
||||||
|
v.SetDefault("event_broker.database.channel", "resolvespec_events")
|
||||||
|
v.SetDefault("event_broker.database.poll_interval", "1s")
|
||||||
|
|
||||||
|
// Event Broker - Retry Policy defaults
|
||||||
|
v.SetDefault("event_broker.retry_policy.max_retries", 3)
|
||||||
|
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
|
||||||
|
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
|
||||||
|
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
|
||||||
}
|
}
|
||||||
|
|||||||
327
pkg/eventbroker/README.md
Normal file
327
pkg/eventbroker/README.md
Normal file
@ -0,0 +1,327 @@
|
|||||||
|
# Event Broker System
|
||||||
|
|
||||||
|
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
|
||||||
|
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
|
||||||
|
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
|
||||||
|
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
|
||||||
|
- **Pattern Matching**: Subscribe to events using glob-style patterns
|
||||||
|
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
|
||||||
|
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
|
||||||
|
- **Retry Logic**: Configurable retry policy with exponential backoff
|
||||||
|
- **Metrics**: Prometheus-compatible metrics for monitoring
|
||||||
|
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Configuration
|
||||||
|
|
||||||
|
Add to your `config.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
enabled: true
|
||||||
|
provider: memory # memory, redis, nats, database
|
||||||
|
mode: async # sync, async
|
||||||
|
worker_count: 10
|
||||||
|
buffer_size: 1000
|
||||||
|
instance_id: "${HOSTNAME}"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Initialize
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// Load configuration
|
||||||
|
cfgMgr := config.NewManager()
|
||||||
|
cfg, _ := cfgMgr.GetConfig()
|
||||||
|
|
||||||
|
// Initialize event broker
|
||||||
|
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Subscribe to Events
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Subscribe to specific events
|
||||||
|
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *eventbroker.Event) error {
|
||||||
|
log.Printf("New user created: %s", event.Payload)
|
||||||
|
// Send welcome email, update cache, etc.
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Subscribe with patterns
|
||||||
|
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *eventbroker.Event) error {
|
||||||
|
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Publish Events
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Create and publish an event
|
||||||
|
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
|
||||||
|
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
|
||||||
|
event.UserID = 123
|
||||||
|
event.SessionID = "session-456"
|
||||||
|
event.Schema = "public"
|
||||||
|
event.Entity = "users"
|
||||||
|
event.Operation = "update"
|
||||||
|
|
||||||
|
event.SetPayload(map[string]interface{}{
|
||||||
|
"id": 123,
|
||||||
|
"name": "John Doe",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Async (non-blocking)
|
||||||
|
eventbroker.PublishAsync(ctx, event)
|
||||||
|
|
||||||
|
// Sync (blocking)
|
||||||
|
eventbroker.PublishSync(ctx, event)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Automatic CRUD Event Capture
|
||||||
|
|
||||||
|
Automatically capture database CRUD operations:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupHooks(handler *restheadspec.Handler) {
|
||||||
|
broker := eventbroker.GetDefaultBroker()
|
||||||
|
|
||||||
|
// Configure which operations to capture
|
||||||
|
config := eventbroker.DefaultCRUDHookConfig()
|
||||||
|
config.EnableRead = false // Disable read events for performance
|
||||||
|
|
||||||
|
// Register hooks
|
||||||
|
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
|
||||||
|
|
||||||
|
// Now all create/update/delete operations automatically publish events!
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Event Structure
|
||||||
|
|
||||||
|
Every event contains:
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Event struct {
|
||||||
|
ID string // UUID
|
||||||
|
Source EventSource // database, websocket, system, frontend, internal
|
||||||
|
Type string // Pattern: schema.entity.operation
|
||||||
|
Status EventStatus // pending, processing, completed, failed
|
||||||
|
Payload json.RawMessage // JSON payload
|
||||||
|
UserID int // User who triggered the event
|
||||||
|
SessionID string // Session identifier
|
||||||
|
InstanceID string // Server instance identifier
|
||||||
|
Schema string // Database schema
|
||||||
|
Entity string // Database entity/table
|
||||||
|
Operation string // create, update, delete, read
|
||||||
|
CreatedAt time.Time // When event was created
|
||||||
|
ProcessedAt *time.Time // When processing started
|
||||||
|
CompletedAt *time.Time // When processing completed
|
||||||
|
Error string // Error message if failed
|
||||||
|
Metadata map[string]interface{} // Additional context
|
||||||
|
RetryCount int // Number of retry attempts
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pattern Matching
|
||||||
|
|
||||||
|
Subscribe to events using glob-style patterns:
|
||||||
|
|
||||||
|
| Pattern | Matches | Example |
|
||||||
|
|---------|---------|---------|
|
||||||
|
| `*` | All events | Any event |
|
||||||
|
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
|
||||||
|
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
|
||||||
|
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
|
||||||
|
| `public.users.create` | Exact match | Only `public.users.create` |
|
||||||
|
|
||||||
|
## Providers
|
||||||
|
|
||||||
|
### Memory Provider (Default)
|
||||||
|
|
||||||
|
Best for: Development, single-instance deployments
|
||||||
|
|
||||||
|
- **Pros**: Fast, no dependencies, simple
|
||||||
|
- **Cons**: Events lost on restart, single-instance only
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: memory
|
||||||
|
```
|
||||||
|
|
||||||
|
### Redis Provider (Future)
|
||||||
|
|
||||||
|
Best for: Production, multi-instance deployments
|
||||||
|
|
||||||
|
- **Pros**: Persistent, cross-instance pub/sub, reliable
|
||||||
|
- **Cons**: Requires Redis
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: redis
|
||||||
|
redis:
|
||||||
|
stream_name: "resolvespec:events"
|
||||||
|
consumer_group: "resolvespec-workers"
|
||||||
|
host: "localhost"
|
||||||
|
port: 6379
|
||||||
|
```
|
||||||
|
|
||||||
|
### NATS Provider (Future)
|
||||||
|
|
||||||
|
Best for: High-performance, low-latency requirements
|
||||||
|
|
||||||
|
- **Pros**: Very fast, built-in clustering, durable
|
||||||
|
- **Cons**: Requires NATS server
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: nats
|
||||||
|
nats:
|
||||||
|
url: "nats://localhost:4222"
|
||||||
|
stream_name: "RESOLVESPEC_EVENTS"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Provider (Future)
|
||||||
|
|
||||||
|
Best for: Audit trails, event replay, SQL queries
|
||||||
|
|
||||||
|
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
|
||||||
|
- **Cons**: Slower than Redis/NATS
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
provider: database
|
||||||
|
database:
|
||||||
|
table_name: "events"
|
||||||
|
channel: "resolvespec_events"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Processing Modes
|
||||||
|
|
||||||
|
### Async Mode (Recommended)
|
||||||
|
|
||||||
|
Events are queued and processed by worker pool:
|
||||||
|
|
||||||
|
- Non-blocking event publishing
|
||||||
|
- Configurable worker count
|
||||||
|
- Better throughput
|
||||||
|
- Events may be processed out of order
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
mode: async
|
||||||
|
worker_count: 10
|
||||||
|
buffer_size: 1000
|
||||||
|
```
|
||||||
|
|
||||||
|
### Sync Mode
|
||||||
|
|
||||||
|
Events are processed immediately:
|
||||||
|
|
||||||
|
- Blocking event publishing
|
||||||
|
- Guaranteed ordering
|
||||||
|
- Immediate error feedback
|
||||||
|
- Lower throughput
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
mode: sync
|
||||||
|
```
|
||||||
|
|
||||||
|
## Retry Policy
|
||||||
|
|
||||||
|
Configure automatic retries for failed handlers:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
event_broker:
|
||||||
|
retry_policy:
|
||||||
|
max_retries: 3
|
||||||
|
initial_delay: 1s
|
||||||
|
max_delay: 30s
|
||||||
|
backoff_factor: 2.0 # Exponential backoff
|
||||||
|
```
|
||||||
|
|
||||||
|
## Metrics
|
||||||
|
|
||||||
|
The event broker exposes Prometheus metrics:
|
||||||
|
|
||||||
|
- `eventbroker_events_published_total{source, type}` - Total events published
|
||||||
|
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
|
||||||
|
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
|
||||||
|
- `eventbroker_queue_size` - Current queue size (async mode)
|
||||||
|
|
||||||
|
## Best Practices
|
||||||
|
|
||||||
|
1. **Use Async Mode**: For better performance, use async mode in production
|
||||||
|
2. **Disable Read Events**: Read events can be high volume; disable if not needed
|
||||||
|
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
|
||||||
|
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
|
||||||
|
5. **Idempotency**: Make handlers idempotent as events may be retried
|
||||||
|
6. **Payload Size**: Keep payloads reasonable; avoid large objects
|
||||||
|
7. **Monitoring**: Monitor metrics to detect issues early
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
See `example_usage.go` for comprehensive examples including:
|
||||||
|
- Basic event publishing and subscription
|
||||||
|
- Hook integration
|
||||||
|
- Error handling
|
||||||
|
- Configuration
|
||||||
|
- Pattern matching
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────┐
|
||||||
|
│ Application │
|
||||||
|
└────────┬────────┘
|
||||||
|
│
|
||||||
|
├─ Publish Events
|
||||||
|
│
|
||||||
|
┌────────▼────────┐ ┌──────────────┐
|
||||||
|
│ Event Broker │◄────►│ Subscribers │
|
||||||
|
└────────┬────────┘ └──────────────┘
|
||||||
|
│
|
||||||
|
├─ Store Events
|
||||||
|
│
|
||||||
|
┌────────▼────────┐
|
||||||
|
│ Provider │
|
||||||
|
│ (Memory/Redis │
|
||||||
|
│ /NATS/DB) │
|
||||||
|
└─────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## Future Enhancements
|
||||||
|
|
||||||
|
- [ ] Database Provider with PostgreSQL NOTIFY
|
||||||
|
- [ ] Redis Streams Provider
|
||||||
|
- [ ] NATS JetStream Provider
|
||||||
|
- [ ] Event replay functionality
|
||||||
|
- [ ] Dead letter queue
|
||||||
|
- [ ] Event filtering at provider level
|
||||||
|
- [ ] Batch publishing
|
||||||
|
- [ ] Event compression
|
||||||
|
- [ ] Schema versioning
|
||||||
453
pkg/eventbroker/broker.go
Normal file
453
pkg/eventbroker/broker.go
Normal file
@ -0,0 +1,453 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
	"context"
	"fmt"
	"math"
	"sync"
	"sync/atomic"
	"time"

	"github.com/bitechdev/ResolveSpec/pkg/logger"
)
|
||||||
|
|
||||||
|
// Broker is the main interface for event publishing and subscription.
// Implementations are expected to be safe for concurrent use
// (EventBroker below guards its state with atomics and sync primitives).
type Broker interface {
	// Publish publishes an event (mode-dependent: sync or async).
	Publish(ctx context.Context, event *Event) error

	// PublishSync publishes an event synchronously (blocks until all handlers complete).
	PublishSync(ctx context.Context, event *Event) error

	// PublishAsync publishes an event asynchronously (returns immediately).
	PublishAsync(ctx context.Context, event *Event) error

	// Subscribe registers a handler for events matching the pattern
	// (glob-style, e.g. "schema.entity.operation" or "*.*.create").
	Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)

	// Unsubscribe removes a subscription previously returned by Subscribe.
	Unsubscribe(id SubscriptionID) error

	// Start starts the broker (begins processing events).
	Start(ctx context.Context) error

	// Stop stops the broker gracefully (flushes pending events).
	Stop(ctx context.Context) error

	// Stats returns broker statistics.
	Stats(ctx context.Context) (*BrokerStats, error)

	// InstanceID returns the instance ID of this broker.
	InstanceID() string
}
|
||||||
|
|
||||||
|
// ProcessingMode determines how events are processed.
type ProcessingMode string

const (
	// ProcessingModeSync processes events inline: Publish blocks until
	// every matching handler has completed.
	ProcessingModeSync ProcessingMode = "sync"
	// ProcessingModeAsync queues events to a worker pool and returns
	// immediately after the event is stored.
	ProcessingModeAsync ProcessingMode = "async"
)
|
||||||
|
|
||||||
|
// BrokerStats contains broker statistics: a point-in-time snapshot of
// counters and configuration returned by EventBroker.Stats.
type BrokerStats struct {
	InstanceID        string                 `json:"instance_id"`
	Mode              ProcessingMode         `json:"mode"`
	IsRunning         bool                   `json:"is_running"`
	TotalPublished    int64                  `json:"total_published"`
	TotalProcessed    int64                  `json:"total_processed"`
	TotalFailed       int64                  `json:"total_failed"`
	ActiveSubscribers int                    `json:"active_subscribers"`
	QueueSize         int                    `json:"queue_size,omitempty"`     // For async mode
	ActiveWorkers     int                    `json:"active_workers,omitempty"` // For async mode
	ProviderStats     *ProviderStats         `json:"provider_stats,omitempty"`
	AdditionalStats   map[string]interface{} `json:"additional_stats,omitempty"`
}
|
||||||
|
|
||||||
|
// EventBroker implements the Broker interface.
// It contains sync primitives (Once, WaitGroup) and atomics, so values
// must not be copied; always use *EventBroker.
type EventBroker struct {
	provider      Provider             // event storage/transport backend
	subscriptions *subscriptionManager // pattern -> handler registry
	mode          ProcessingMode       // sync or async
	instanceID    string               // identifier stamped into stats
	retryPolicy   *RetryPolicy         // handler retry/backoff settings

	// Async mode fields (initialized in Phase 4)
	workerPool *workerPool // nil in sync mode

	// Runtime state
	isRunning atomic.Bool
	stopOnce  sync.Once // makes Stop idempotent
	stopCh    chan struct{}
	wg        sync.WaitGroup

	// Statistics (lock-free counters; read by Stats)
	statsPublished atomic.Int64
	statsProcessed atomic.Int64
	statsFailed    atomic.Int64
}
|
||||||
|
|
||||||
|
// RetryPolicy defines how failed events should be retried
|
||||||
|
type RetryPolicy struct {
|
||||||
|
MaxRetries int
|
||||||
|
InitialDelay time.Duration
|
||||||
|
MaxDelay time.Duration
|
||||||
|
BackoffFactor float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRetryPolicy returns a sensible default retry policy
|
||||||
|
func DefaultRetryPolicy() *RetryPolicy {
|
||||||
|
return &RetryPolicy{
|
||||||
|
MaxRetries: 3,
|
||||||
|
InitialDelay: 1 * time.Second,
|
||||||
|
MaxDelay: 30 * time.Second,
|
||||||
|
BackoffFactor: 2.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options for creating a new broker via NewBroker.
type Options struct {
	Provider    Provider       // required: event storage/transport backend
	Mode        ProcessingMode // sync or async; NewBroker defaults to async when empty
	WorkerCount int            // For async mode (NewBroker defaults to 10)
	BufferSize  int            // For async mode (NewBroker defaults to 1000)
	RetryPolicy *RetryPolicy   // defaults to DefaultRetryPolicy() when nil
	InstanceID  string         // required: server instance identifier
}
|
||||||
|
|
||||||
|
// NewBroker creates a new event broker with the given options
|
||||||
|
func NewBroker(opts Options) (*EventBroker, error) {
|
||||||
|
if opts.Provider == nil {
|
||||||
|
return nil, fmt.Errorf("provider is required")
|
||||||
|
}
|
||||||
|
if opts.InstanceID == "" {
|
||||||
|
return nil, fmt.Errorf("instance ID is required")
|
||||||
|
}
|
||||||
|
if opts.Mode == "" {
|
||||||
|
opts.Mode = ProcessingModeAsync // Default to async
|
||||||
|
}
|
||||||
|
if opts.RetryPolicy == nil {
|
||||||
|
opts.RetryPolicy = DefaultRetryPolicy()
|
||||||
|
}
|
||||||
|
|
||||||
|
broker := &EventBroker{
|
||||||
|
provider: opts.Provider,
|
||||||
|
subscriptions: newSubscriptionManager(),
|
||||||
|
mode: opts.Mode,
|
||||||
|
instanceID: opts.InstanceID,
|
||||||
|
retryPolicy: opts.RetryPolicy,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker pool will be initialized in Phase 4 for async mode
|
||||||
|
if opts.Mode == ProcessingModeAsync {
|
||||||
|
if opts.WorkerCount == 0 {
|
||||||
|
opts.WorkerCount = 10 // Default
|
||||||
|
}
|
||||||
|
if opts.BufferSize == 0 {
|
||||||
|
opts.BufferSize = 1000 // Default
|
||||||
|
}
|
||||||
|
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
return broker, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Functional option pattern helpers
|
||||||
|
func WithProvider(p Provider) func(*Options) {
|
||||||
|
return func(o *Options) { o.Provider = p }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithMode(m ProcessingMode) func(*Options) {
|
||||||
|
return func(o *Options) { o.Mode = m }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithWorkerCount(count int) func(*Options) {
|
||||||
|
return func(o *Options) { o.WorkerCount = count }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBufferSize(size int) func(*Options) {
|
||||||
|
return func(o *Options) { o.BufferSize = size }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
|
||||||
|
return func(o *Options) { o.RetryPolicy = policy }
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithInstanceID(id string) func(*Options) {
|
||||||
|
return func(o *Options) { o.InstanceID = id }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the broker
|
||||||
|
func (b *EventBroker) Start(ctx context.Context) error {
|
||||||
|
if b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start worker pool for async mode
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
b.workerPool.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the broker gracefully: marks it stopped, signals stopCh,
// drains the worker pool (async mode), waits for in-flight goroutines,
// and closes the provider. sync.Once makes Stop idempotent.
//
// NOTE(review): a second Stop call returns nil regardless of whether the
// first Stop failed, because stopErr is only assigned inside the Once body.
func (b *EventBroker) Stop(ctx context.Context) error {
	var stopErr error

	b.stopOnce.Do(func() {
		logger.Info("Stopping event broker...")

		// Mark as not running (rejects further Publish calls)
		b.isRunning.Store(false)

		// Close the stop channel to signal any listeners
		close(b.stopCh)

		// Stop worker pool for async mode; its error takes precedence
		if b.mode == ProcessingModeAsync && b.workerPool != nil {
			if err := b.workerPool.Stop(ctx); err != nil {
				logger.Error("Error stopping worker pool: %v", err)
				stopErr = err
			}
		}

		// Wait for all goroutines tracked by the broker
		b.wg.Wait()

		// Close provider; keep the first error encountered
		if err := b.provider.Close(); err != nil {
			logger.Error("Error closing provider: %v", err)
			if stopErr == nil {
				stopErr = err
			}
		}

		logger.Info("Event broker stopped")
	})

	return stopErr
}
|
||||||
|
|
||||||
|
// Publish publishes an event based on the broker's mode
|
||||||
|
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
|
||||||
|
if b.mode == ProcessingModeSync {
|
||||||
|
return b.PublishSync(ctx, event)
|
||||||
|
}
|
||||||
|
return b.PublishAsync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishSync publishes an event synchronously
|
||||||
|
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
|
||||||
|
if !b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate event
|
||||||
|
if err := event.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store event in provider
|
||||||
|
if err := b.provider.Publish(ctx, event); err != nil {
|
||||||
|
return fmt.Errorf("failed to publish event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsPublished.Add(1)
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventPublished(event)
|
||||||
|
|
||||||
|
// Process event synchronously
|
||||||
|
if err := b.processEvent(ctx, event); err != nil {
|
||||||
|
logger.Error("Failed to process event %s: %v", event.ID, err)
|
||||||
|
b.statsFailed.Add(1)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsProcessed.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishAsync publishes an event asynchronously
|
||||||
|
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
|
||||||
|
if !b.isRunning.Load() {
|
||||||
|
return fmt.Errorf("broker is not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate event
|
||||||
|
if err := event.Validate(); err != nil {
|
||||||
|
return fmt.Errorf("invalid event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store event in provider
|
||||||
|
if err := b.provider.Publish(ctx, event); err != nil {
|
||||||
|
return fmt.Errorf("failed to publish event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.statsPublished.Add(1)
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
recordEventPublished(event)
|
||||||
|
|
||||||
|
// Queue for async processing
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
// Update queue size metrics
|
||||||
|
updateQueueSize(int64(b.workerPool.QueueSize()))
|
||||||
|
return b.workerPool.Submit(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to sync if async not configured
|
||||||
|
return b.processEvent(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a subscription for events matching the pattern
// (glob-style, e.g. "public.users.*"). It delegates to the subscription
// manager and returns an ID usable with Unsubscribe.
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
	return b.subscriptions.Subscribe(pattern, handler)
}

// Unsubscribe removes a subscription previously returned by Subscribe.
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
	return b.subscriptions.Unsubscribe(id)
}
|
||||||
|
|
||||||
|
// processEvent processes an event by calling all matching handlers.
// Lifecycle: mark processing -> run each handler (with retries) -> mark
// completed or failed, updating the provider's stored status at each step.
// Provider status-update failures are logged but do not abort processing.
//
// NOTE(review): when several handlers fail, only the LAST error is kept in
// lastErr and recorded on the event; earlier handler errors appear only in
// the logs.
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
	startTime := time.Now()

	// Get all handlers matching this event type
	handlers := b.subscriptions.GetMatching(event.Type)

	if len(handlers) == 0 {
		logger.Debug("No handlers for event type: %s", event.Type)
		return nil
	}

	logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))

	// Mark event as processing
	event.MarkProcessing()
	if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
		logger.Warn("Failed to update event status: %v", err)
	}

	// Execute all handlers; a failure does not stop the remaining handlers.
	var lastErr error
	for i, handler := range handlers {
		if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
			logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
			lastErr = err
			// Continue processing other handlers
		}
	}

	// Update final status
	if lastErr != nil {
		event.MarkFailed(lastErr)
		if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
			logger.Warn("Failed to update event status: %v", err)
		}

		// Record metrics
		recordEventProcessed(event, time.Since(startTime))

		return lastErr
	}

	event.MarkCompleted()
	if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
		logger.Warn("Failed to update event status: %v", err)
	}

	// Record metrics
	recordEventProcessed(event, time.Since(startTime))

	return nil
}
|
||||||
|
|
||||||
|
// executeHandlerWithRetry executes a handler with retry logic.
// It makes up to MaxRetries+1 attempts; before each retry it sleeps for an
// exponential-backoff delay (interruptible by ctx cancellation) and bumps
// the event's retry counter. Returns nil on the first success, or a
// wrapped error once all attempts are exhausted.
//
// NOTE(review): event.IncrementRetry() mutates the shared event, so when
// several handlers each retry, RetryCount accumulates across all of them.
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
	var lastErr error

	for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
		if attempt > 0 {
			// Calculate backoff delay
			delay := b.calculateBackoff(attempt)
			logger.Debug("Retrying event %s (attempt %d/%d) after %v",
				event.ID, attempt, b.retryPolicy.MaxRetries, delay)

			// Sleep for the backoff delay, but bail out if the context
			// is cancelled while waiting.
			select {
			case <-time.After(delay):
			case <-ctx.Done():
				return ctx.Err()
			}

			event.IncrementRetry()
		}

		// Execute handler
		if err := handler.Handle(ctx, event); err != nil {
			lastErr = err
			logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
			continue
		}

		// Success
		return nil
	}

	return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
}
|
||||||
|
|
||||||
|
// calculateBackoff calculates the backoff delay for a retry attempt
|
||||||
|
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
|
||||||
|
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
|
||||||
|
if delay > float64(b.retryPolicy.MaxDelay) {
|
||||||
|
delay = float64(b.retryPolicy.MaxDelay)
|
||||||
|
}
|
||||||
|
return time.Duration(delay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pow returns base raised to exp.
//
// Fix: the original comment called this an "integer power function", but
// the signature takes a float64 exponent and the loop-based implementation
// returned wrong results for fractional or negative exponents (e.g.
// pow(4, 0.5) was 4, not 2). Delegating to math.Pow handles all exponents
// correctly while preserving the existing integer-exponent results used by
// calculateBackoff.
func pow(base float64, exp float64) float64 {
	return math.Pow(base, exp)
}
|
||||||
|
|
||||||
|
// Stats returns broker statistics
|
||||||
|
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
|
||||||
|
providerStats, err := b.provider.Stats(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to get provider stats: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := &BrokerStats{
|
||||||
|
InstanceID: b.instanceID,
|
||||||
|
Mode: b.mode,
|
||||||
|
IsRunning: b.isRunning.Load(),
|
||||||
|
TotalPublished: b.statsPublished.Load(),
|
||||||
|
TotalProcessed: b.statsProcessed.Load(),
|
||||||
|
TotalFailed: b.statsFailed.Load(),
|
||||||
|
ActiveSubscribers: b.subscriptions.Count(),
|
||||||
|
ProviderStats: providerStats,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add async-specific stats
|
||||||
|
if b.mode == ProcessingModeAsync && b.workerPool != nil {
|
||||||
|
stats.QueueSize = b.workerPool.QueueSize()
|
||||||
|
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstanceID returns the instance ID this broker was created with.
func (b *EventBroker) InstanceID() string {
	return b.instanceID
}
|
||||||
524
pkg/eventbroker/broker_test.go
Normal file
524
pkg/eventbroker/broker_test.go
Normal file
@ -0,0 +1,524 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewBroker table-tests constructor validation: required Provider and
// InstanceID, and async-mode defaulting.
func TestNewBroker(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
		MaxEvents:  1000,
	})

	tests := []struct {
		name      string
		opts      Options
		wantError bool
	}{
		{
			name: "valid options",
			opts: Options{
				Provider:   provider,
				InstanceID: "test-instance",
				Mode:       ProcessingModeSync,
			},
			wantError: false,
		},
		{
			// Provider is mandatory
			name: "missing provider",
			opts: Options{
				InstanceID: "test-instance",
			},
			wantError: true,
		},
		{
			// InstanceID is mandatory
			name: "missing instance ID",
			opts: Options{
				Provider: provider,
			},
			wantError: true,
		},
		{
			// WorkerCount/BufferSize omitted: NewBroker applies defaults
			name: "async mode with defaults",
			opts: Options{
				Provider:   provider,
				InstanceID: "test-instance",
				Mode:       ProcessingModeAsync,
			},
			wantError: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			broker, err := NewBroker(tt.opts)
			if (err != nil) != tt.wantError {
				t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
			}
			if err == nil && broker == nil {
				t.Error("Expected non-nil broker")
			}
		})
	}
}
|
||||||
|
|
||||||
|
// TestBrokerStartStop verifies the lifecycle contract: Start succeeds
// once, a second Start fails, Stop succeeds, and a second Stop is a no-op
// (sync.Once) rather than an error.
func TestBrokerStartStop(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
	})

	broker, err := NewBroker(Options{
		Provider:   provider,
		InstanceID: "test-instance",
		Mode:       ProcessingModeSync,
	})
	if err != nil {
		t.Fatalf("Failed to create broker: %v", err)
	}

	// Test Start
	if err := broker.Start(context.Background()); err != nil {
		t.Fatalf("Failed to start broker: %v", err)
	}

	// Test double start (should fail)
	if err := broker.Start(context.Background()); err == nil {
		t.Error("Expected error on double start")
	}

	// Test Stop
	if err := broker.Stop(context.Background()); err != nil {
		t.Fatalf("Failed to stop broker: %v", err)
	}

	// Test double stop (should not fail)
	if err := broker.Stop(context.Background()); err != nil {
		t.Error("Double stop should not fail")
	}
}
|
||||||
|
|
||||||
|
// TestBrokerPublishSync verifies that a synchronous publish runs the
// matching handler before returning and marks the event completed.
// The plain bool/pointer captures are safe here because sync mode runs
// handlers on the publishing goroutine.
func TestBrokerPublishSync(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
	})

	broker, _ := NewBroker(Options{
		Provider:   provider,
		InstanceID: "test-instance",
		Mode:       ProcessingModeSync,
	})
	broker.Start(context.Background())
	defer broker.Stop(context.Background())

	// Subscribe to events
	called := false
	var receivedEvent *Event
	broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
		called = true
		receivedEvent = event
		return nil
	}))

	// Publish event
	event := NewEvent(EventSourceSystem, "test.event")
	event.InstanceID = "test-instance"
	err := broker.PublishSync(context.Background(), event)
	if err != nil {
		t.Fatalf("PublishSync failed: %v", err)
	}

	// Verify handler was called
	if !called {
		t.Error("Expected handler to be called")
	}
	if receivedEvent == nil || receivedEvent.ID != event.ID {
		t.Error("Expected to receive the published event")
	}

	// Verify event status
	if event.Status != EventStatusCompleted {
		t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
	}
}
|
||||||
|
|
||||||
|
// TestBrokerPublishAsync publishes five events through the worker pool and
// checks that every handler invocation happened.
// NOTE(review): the fixed 100ms sleep is a potential flake on a loaded CI
// machine; a polling wait or done-channel would be more robust.
func TestBrokerPublishAsync(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
	})

	broker, _ := NewBroker(Options{
		Provider:    provider,
		InstanceID:  "test-instance",
		Mode:        ProcessingModeAsync,
		WorkerCount: 2,
		BufferSize:  10,
	})
	broker.Start(context.Background())
	defer broker.Stop(context.Background())

	// Subscribe to events; atomic counter because handlers run on workers.
	var callCount atomic.Int32
	broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
		callCount.Add(1)
		return nil
	}))

	// Publish multiple events
	for i := 0; i < 5; i++ {
		event := NewEvent(EventSourceSystem, "test.event")
		event.InstanceID = "test-instance"
		if err := broker.PublishAsync(context.Background(), event); err != nil {
			t.Fatalf("PublishAsync failed: %v", err)
		}
	}

	// Wait for events to be processed
	time.Sleep(100 * time.Millisecond)

	if callCount.Load() != 5 {
		t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
	}
}
|
||||||
|
|
||||||
|
// TestBrokerPublishBeforeStart verifies that publishing to a broker that
// was never started is rejected ("broker is not running").
func TestBrokerPublishBeforeStart(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
	})

	broker, _ := NewBroker(Options{
		Provider:   provider,
		InstanceID: "test-instance",
	})

	event := NewEvent(EventSourceSystem, "test.event")
	event.InstanceID = "test-instance"
	err := broker.Publish(context.Background(), event)
	if err == nil {
		t.Error("Expected error when publishing before start")
	}
}
|
||||||
|
|
||||||
|
// TestBrokerHandlerError verifies the retry path: a handler that always
// fails is invoked MaxRetries+1 times, PublishSync surfaces the error,
// and the event ends up in the failed state.
func TestBrokerHandlerError(t *testing.T) {
	provider := NewMemoryProvider(MemoryProviderOptions{
		InstanceID: "test-instance",
	})

	broker, _ := NewBroker(Options{
		Provider:   provider,
		InstanceID: "test-instance",
		Mode:       ProcessingModeSync,
		// Short delays keep the retry loop fast in tests.
		RetryPolicy: &RetryPolicy{
			MaxRetries:    2,
			InitialDelay:  10 * time.Millisecond,
			MaxDelay:      100 * time.Millisecond,
			BackoffFactor: 2.0,
		},
	})
	broker.Start(context.Background())
	defer broker.Stop(context.Background())

	// Subscribe with failing handler
	var callCount atomic.Int32
	broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
		callCount.Add(1)
		return errors.New("handler error")
	}))

	// Publish event
	event := NewEvent(EventSourceSystem, "test.event")
	event.InstanceID = "test-instance"
	err := broker.PublishSync(context.Background(), event)

	// Should fail after retries
	if err == nil {
		t.Error("Expected error from handler")
	}

	// Should have been called MaxRetries+1 times (initial + retries)
	if callCount.Load() != 3 {
		t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
	}

	// Event should be marked as failed
	if event.Status != EventStatusFailed {
		t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
	}
}
|
||||||
|
|
||||||
|
func TestBrokerMultipleHandlers(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe multiple handlers
|
||||||
|
var called1, called2, called3 bool
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called1 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called2 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called3 = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
|
||||||
|
// All handlers should be called
|
||||||
|
if !called1 || !called2 || !called3 {
|
||||||
|
t.Error("Expected all handlers to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerUnsubscribe(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
called := false
|
||||||
|
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Unsubscribe
|
||||||
|
if err := broker.Unsubscribe(id); err != nil {
|
||||||
|
t.Fatalf("Unsubscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
|
||||||
|
// Handler should not be called
|
||||||
|
if called {
|
||||||
|
t.Error("Expected handler not to be called after unsubscribe")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerStats(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeSync,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
// Subscribe
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishSync(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get stats
|
||||||
|
stats, err := broker.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.InstanceID != "test-instance" {
|
||||||
|
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
|
||||||
|
}
|
||||||
|
if stats.TotalPublished != 3 {
|
||||||
|
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
|
||||||
|
}
|
||||||
|
if stats.TotalProcessed != 3 {
|
||||||
|
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
|
||||||
|
}
|
||||||
|
if stats.ActiveSubscribers != 1 {
|
||||||
|
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerInstanceID(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "my-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
if broker.InstanceID() != "my-instance" {
|
||||||
|
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerConcurrentPublish(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 5,
|
||||||
|
BufferSize: 100,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
var callCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
callCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish concurrently
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishAsync(context.Background(), event)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
time.Sleep(200 * time.Millisecond) // Wait for async processing
|
||||||
|
|
||||||
|
if callCount.Load() != 50 {
|
||||||
|
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerGracefulShutdown(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 2,
|
||||||
|
BufferSize: 10,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
|
||||||
|
var processedCount atomic.Int32
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
time.Sleep(50 * time.Millisecond) // Simulate work
|
||||||
|
processedCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.PublishAsync(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop broker (should wait for events to be processed)
|
||||||
|
if err := broker.Stop(context.Background()); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All events should be processed
|
||||||
|
if processedCount.Load() != 5 {
|
||||||
|
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerDefaultRetryPolicy(t *testing.T) {
|
||||||
|
policy := DefaultRetryPolicy()
|
||||||
|
|
||||||
|
if policy.MaxRetries != 3 {
|
||||||
|
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
|
||||||
|
}
|
||||||
|
if policy.InitialDelay != 1*time.Second {
|
||||||
|
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
|
||||||
|
}
|
||||||
|
if policy.BackoffFactor != 2.0 {
|
||||||
|
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBrokerProcessingModes(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
mode ProcessingMode
|
||||||
|
}{
|
||||||
|
{"sync mode", ProcessingModeSync},
|
||||||
|
{"async mode", ProcessingModeAsync},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
broker, _ := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
Mode: tt.mode,
|
||||||
|
})
|
||||||
|
broker.Start(context.Background())
|
||||||
|
defer broker.Stop(context.Background())
|
||||||
|
|
||||||
|
called := false
|
||||||
|
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceSystem, "test.event")
|
||||||
|
event.InstanceID = "test-instance"
|
||||||
|
broker.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
if tt.mode == ProcessingModeAsync {
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
175
pkg/eventbroker/event.go
Normal file
175
pkg/eventbroker/event.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EventSource represents where an event originated from.
type EventSource string

// Known event origins. These are string-typed so they serialize directly
// in JSON and database columns.
const (
	EventSourceDatabase  EventSource = "database"
	EventSourceWebSocket EventSource = "websocket"
	EventSourceFrontend  EventSource = "frontend"
	EventSourceSystem    EventSource = "system"
	EventSourceInternal  EventSource = "internal"
)
|
||||||
|
|
||||||
|
// EventStatus represents the current state of an event as it moves through
// the broker: pending -> processing -> completed | failed.
type EventStatus string

// Event lifecycle states.
const (
	EventStatusPending    EventStatus = "pending"
	EventStatusProcessing EventStatus = "processing"
	EventStatusCompleted  EventStatus = "completed"
	EventStatusFailed     EventStatus = "failed"
)
|
||||||
|
|
||||||
|
// Event represents a single event in the system with complete metadata.
// Fields carry both `json` and `db` tags so the same struct can be sent
// over the wire and persisted by a storage provider.
type Event struct {
	// Identification
	ID string `json:"id" db:"id"` // UUID, generated by NewEvent

	// Source & Classification
	Source EventSource `json:"source" db:"source"`
	Type   string      `json:"type" db:"type"` // Pattern: schema.entity.operation

	// Status Tracking
	Status     EventStatus `json:"status" db:"status"`
	RetryCount int         `json:"retry_count" db:"retry_count"`
	Error      string      `json:"error,omitempty" db:"error"` // set by MarkFailed

	// Payload; raw JSON so decoding can be deferred to the consumer.
	Payload json.RawMessage `json:"payload" db:"payload"`

	// Context Information
	UserID     int    `json:"user_id" db:"user_id"`
	SessionID  string `json:"session_id" db:"session_id"`
	InstanceID string `json:"instance_id" db:"instance_id"` // required by Validate

	// Database Context
	Schema    string `json:"schema" db:"schema"`
	Entity    string `json:"entity" db:"entity"`
	Operation string `json:"operation" db:"operation"` // create, update, delete, read

	// Timestamps; ProcessedAt/CompletedAt are nil until the matching
	// Mark* method is called.
	CreatedAt   time.Time  `json:"created_at" db:"created_at"`
	ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
	CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`

	// Extensibility
	Metadata map[string]interface{} `json:"metadata" db:"metadata"`
}
|
||||||
|
|
||||||
|
// NewEvent creates a new event with defaults
|
||||||
|
func NewEvent(source EventSource, eventType string) *Event {
|
||||||
|
return &Event{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Source: source,
|
||||||
|
Type: eventType,
|
||||||
|
Status: EventStatusPending,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Metadata: make(map[string]interface{}),
|
||||||
|
RetryCount: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// EventType generates a type string from schema, entity, and operation.
// Pattern: schema.entity.operation (e.g., "public.users.create").
// Plain concatenation is used instead of fmt.Sprintf: it avoids the
// reflection-based formatting and interface boxing for this trivial join.
func EventType(schema, entity, operation string) string {
	return schema + "." + entity + "." + operation
}
|
||||||
|
|
||||||
|
// MarkProcessing marks the event as being processed
|
||||||
|
func (e *Event) MarkProcessing() {
|
||||||
|
e.Status = EventStatusProcessing
|
||||||
|
now := time.Now()
|
||||||
|
e.ProcessedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkCompleted marks the event as successfully completed
|
||||||
|
func (e *Event) MarkCompleted() {
|
||||||
|
e.Status = EventStatusCompleted
|
||||||
|
now := time.Now()
|
||||||
|
e.CompletedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkFailed marks the event as failed with an error message
|
||||||
|
func (e *Event) MarkFailed(err error) {
|
||||||
|
e.Status = EventStatusFailed
|
||||||
|
e.Error = err.Error()
|
||||||
|
now := time.Now()
|
||||||
|
e.CompletedAt = &now
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementRetry increments the retry counter. Not goroutine-safe; the
// broker is expected to call this from a single processing goroutine.
func (e *Event) IncrementRetry() {
	e.RetryCount++
}
|
||||||
|
|
||||||
|
// SetPayload sets the event payload from any value by marshaling to JSON
|
||||||
|
func (e *Event) SetPayload(v interface{}) error {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal payload: %w", err)
|
||||||
|
}
|
||||||
|
e.Payload = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPayload unmarshals the payload into the provided value
|
||||||
|
func (e *Event) GetPayload(v interface{}) error {
|
||||||
|
if len(e.Payload) == 0 {
|
||||||
|
return fmt.Errorf("payload is empty")
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(e.Payload, v); err != nil {
|
||||||
|
return fmt.Errorf("failed to unmarshal payload: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone creates a deep copy of the event
|
||||||
|
func (e *Event) Clone() *Event {
|
||||||
|
clone := *e
|
||||||
|
|
||||||
|
// Deep copy metadata
|
||||||
|
if e.Metadata != nil {
|
||||||
|
clone.Metadata = make(map[string]interface{})
|
||||||
|
for k, v := range e.Metadata {
|
||||||
|
clone.Metadata[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deep copy timestamps
|
||||||
|
if e.ProcessedAt != nil {
|
||||||
|
t := *e.ProcessedAt
|
||||||
|
clone.ProcessedAt = &t
|
||||||
|
}
|
||||||
|
if e.CompletedAt != nil {
|
||||||
|
t := *e.CompletedAt
|
||||||
|
clone.CompletedAt = &t
|
||||||
|
}
|
||||||
|
|
||||||
|
return &clone
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate performs basic validation on the event
|
||||||
|
func (e *Event) Validate() error {
|
||||||
|
if e.ID == "" {
|
||||||
|
return fmt.Errorf("event ID is required")
|
||||||
|
}
|
||||||
|
if e.Source == "" {
|
||||||
|
return fmt.Errorf("event source is required")
|
||||||
|
}
|
||||||
|
if e.Type == "" {
|
||||||
|
return fmt.Errorf("event type is required")
|
||||||
|
}
|
||||||
|
if e.InstanceID == "" {
|
||||||
|
return fmt.Errorf("instance ID is required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
314
pkg/eventbroker/event_test.go
Normal file
314
pkg/eventbroker/event_test.go
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewEvent(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
if event.ID == "" {
|
||||||
|
t.Error("Expected event ID to be generated")
|
||||||
|
}
|
||||||
|
if event.Source != EventSourceDatabase {
|
||||||
|
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
|
||||||
|
}
|
||||||
|
if event.Type != "public.users.create" {
|
||||||
|
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
|
||||||
|
}
|
||||||
|
if event.Status != EventStatusPending {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
|
||||||
|
}
|
||||||
|
if event.CreatedAt.IsZero() {
|
||||||
|
t.Error("Expected CreatedAt to be set")
|
||||||
|
}
|
||||||
|
if event.Metadata == nil {
|
||||||
|
t.Error("Expected Metadata to be initialized")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
schema string
|
||||||
|
entity string
|
||||||
|
operation string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"public", "users", "create", "public.users.create"},
|
||||||
|
{"admin", "roles", "update", "admin.roles.update"},
|
||||||
|
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := EventType(tt.schema, tt.entity, tt.operation)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
|
||||||
|
tt.schema, tt.entity, tt.operation, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventValidate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
event *Event
|
||||||
|
wantError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid event",
|
||||||
|
event: func() *Event {
|
||||||
|
e := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
e.InstanceID = "test-instance"
|
||||||
|
return e
|
||||||
|
}(),
|
||||||
|
wantError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing ID",
|
||||||
|
event: &Event{
|
||||||
|
Source: EventSourceDatabase,
|
||||||
|
Type: "public.users.create",
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing source",
|
||||||
|
event: &Event{
|
||||||
|
ID: "test-id",
|
||||||
|
Type: "public.users.create",
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing type",
|
||||||
|
event: &Event{
|
||||||
|
ID: "test-id",
|
||||||
|
Source: EventSourceDatabase,
|
||||||
|
Status: EventStatusPending,
|
||||||
|
},
|
||||||
|
wantError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.event.Validate()
|
||||||
|
if (err != nil) != tt.wantError {
|
||||||
|
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventSetPayload(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"id": 1,
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
err := event.SetPayload(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("SetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Payload == nil {
|
||||||
|
t.Fatal("Expected payload to be set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify payload can be unmarshaled
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.Unmarshal(event.Payload, &result); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal payload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventGetPayload(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"id": float64(1), // JSON unmarshals numbers as float64
|
||||||
|
"name": "John Doe",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := event.SetPayload(payload); err != nil {
|
||||||
|
t.Fatalf("SetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := event.GetPayload(&result); err != nil {
|
||||||
|
t.Fatalf("GetPayload failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["name"] != "John Doe" {
|
||||||
|
t.Errorf("Expected name 'John Doe', got %v", result["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkProcessing(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.MarkProcessing()
|
||||||
|
|
||||||
|
if event.Status != EventStatusProcessing {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
|
||||||
|
}
|
||||||
|
if event.ProcessedAt == nil {
|
||||||
|
t.Error("Expected ProcessedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkCompleted(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.MarkCompleted()
|
||||||
|
|
||||||
|
if event.Status != EventStatusCompleted {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
|
||||||
|
}
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Error("Expected CompletedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMarkFailed(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
testErr := errors.New("test error")
|
||||||
|
event.MarkFailed(testErr)
|
||||||
|
|
||||||
|
if event.Status != EventStatusFailed {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
|
||||||
|
}
|
||||||
|
if event.Error != "test error" {
|
||||||
|
t.Errorf("Expected error %q, got %q", "test error", event.Error)
|
||||||
|
}
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Error("Expected CompletedAt to be set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventIncrementRetry(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
initialCount := event.RetryCount
|
||||||
|
event.IncrementRetry()
|
||||||
|
|
||||||
|
if event.RetryCount != initialCount+1 {
|
||||||
|
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventJSONMarshaling(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.UserID = 123
|
||||||
|
event.SessionID = "session-123"
|
||||||
|
event.InstanceID = "instance-1"
|
||||||
|
event.Schema = "public"
|
||||||
|
event.Entity = "users"
|
||||||
|
event.Operation = "create"
|
||||||
|
event.SetPayload(map[string]interface{}{"name": "Test"})
|
||||||
|
|
||||||
|
// Marshal to JSON
|
||||||
|
data, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to marshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal back
|
||||||
|
var decoded Event
|
||||||
|
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||||
|
t.Fatalf("Failed to unmarshal event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify fields
|
||||||
|
if decoded.ID != event.ID {
|
||||||
|
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
|
||||||
|
}
|
||||||
|
if decoded.Source != event.Source {
|
||||||
|
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
|
||||||
|
}
|
||||||
|
if decoded.UserID != event.UserID {
|
||||||
|
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventStatusString(t *testing.T) {
|
||||||
|
statuses := []EventStatus{
|
||||||
|
EventStatusPending,
|
||||||
|
EventStatusProcessing,
|
||||||
|
EventStatusCompleted,
|
||||||
|
EventStatusFailed,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, status := range statuses {
|
||||||
|
if string(status) == "" {
|
||||||
|
t.Errorf("EventStatus %v has empty string representation", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventSourceString(t *testing.T) {
|
||||||
|
sources := []EventSource{
|
||||||
|
EventSourceDatabase,
|
||||||
|
EventSourceWebSocket,
|
||||||
|
EventSourceFrontend,
|
||||||
|
EventSourceSystem,
|
||||||
|
EventSourceInternal,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, source := range sources {
|
||||||
|
if string(source) == "" {
|
||||||
|
t.Errorf("EventSource %v has empty string representation", source)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventMetadata(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
|
||||||
|
// Test setting metadata
|
||||||
|
event.Metadata["key1"] = "value1"
|
||||||
|
event.Metadata["key2"] = 123
|
||||||
|
|
||||||
|
if event.Metadata["key1"] != "value1" {
|
||||||
|
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
|
||||||
|
}
|
||||||
|
if event.Metadata["key2"] != 123 {
|
||||||
|
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventTimestamps(t *testing.T) {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
createdAt := event.CreatedAt
|
||||||
|
|
||||||
|
// Wait a tiny bit to ensure timestamps differ
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
event.MarkProcessing()
|
||||||
|
if event.ProcessedAt == nil {
|
||||||
|
t.Fatal("ProcessedAt should be set")
|
||||||
|
}
|
||||||
|
if !event.ProcessedAt.After(createdAt) {
|
||||||
|
t.Error("ProcessedAt should be after CreatedAt")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
|
||||||
|
event.MarkCompleted()
|
||||||
|
if event.CompletedAt == nil {
|
||||||
|
t.Fatal("CompletedAt should be set")
|
||||||
|
}
|
||||||
|
if !event.CompletedAt.After(*event.ProcessedAt) {
|
||||||
|
t.Error("CompletedAt should be after ProcessedAt")
|
||||||
|
}
|
||||||
|
}
|
||||||
160
pkg/eventbroker/eventbroker.go
Normal file
160
pkg/eventbroker/eventbroker.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultBroker is the process-wide broker installed by Initialize /
// SetDefaultBroker; brokerMu guards all access to it.
var (
	defaultBroker Broker
	brokerMu      sync.RWMutex
)
|
||||||
|
|
||||||
|
// Initialize initializes the global event broker from configuration.
// It is a no-op when cfg.Enabled is false. On success the broker is
// created from the configured provider, started, installed as the
// package-level default, and registered with the server's shutdown
// callbacks.
func Initialize(cfg config.EventBrokerConfig) error {
	if !cfg.Enabled {
		logger.Info("Event broker is disabled")
		return nil
	}

	// Create provider
	provider, err := NewProviderFromConfig(cfg)
	if err != nil {
		return fmt.Errorf("failed to create provider: %w", err)
	}

	// Parse mode: anything other than the literal "sync" selects async.
	mode := ProcessingModeAsync
	if cfg.Mode == "sync" {
		mode = ProcessingModeSync
	}

	// Convert retry policy
	retryPolicy := &RetryPolicy{
		MaxRetries:    cfg.RetryPolicy.MaxRetries,
		InitialDelay:  cfg.RetryPolicy.InitialDelay,
		MaxDelay:      cfg.RetryPolicy.MaxDelay,
		BackoffFactor: cfg.RetryPolicy.BackoffFactor,
	}
	// NOTE(review): MaxRetries == 0 is treated as "unset" and replaced
	// with the defaults, so a config cannot explicitly request zero
	// retries — confirm this is intended.
	if retryPolicy.MaxRetries == 0 {
		retryPolicy = DefaultRetryPolicy()
	}

	// Create broker options
	opts := Options{
		Provider:    provider,
		Mode:        mode,
		WorkerCount: cfg.WorkerCount,
		BufferSize:  cfg.BufferSize,
		RetryPolicy: retryPolicy,
		InstanceID:  getInstanceID(cfg.InstanceID),
	}

	// Create broker
	broker, err := NewBroker(opts)
	if err != nil {
		return fmt.Errorf("failed to create broker: %w", err)
	}

	// Start broker
	if err := broker.Start(context.Background()); err != nil {
		return fmt.Errorf("failed to start broker: %w", err)
	}

	// Set as default
	SetDefaultBroker(broker)

	// Register shutdown callback
	RegisterShutdown(broker)

	logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
		cfg.Provider, cfg.Mode, opts.InstanceID)

	return nil
}
|
||||||
|
|
||||||
|
// SetDefaultBroker sets the default global broker. Safe for concurrent use
// with the package-level helper functions.
func SetDefaultBroker(broker Broker) {
	brokerMu.Lock()
	defer brokerMu.Unlock()
	defaultBroker = broker
}
|
||||||
|
|
||||||
|
// GetDefaultBroker returns the default global broker, or nil if Initialize
// (or SetDefaultBroker) has not been called.
func GetDefaultBroker() Broker {
	brokerMu.RLock()
	defer brokerMu.RUnlock()
	return defaultBroker
}
|
||||||
|
|
||||||
|
// IsInitialized returns true if the default broker is initialized
|
||||||
|
func IsInitialized() bool {
|
||||||
|
return GetDefaultBroker() != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes an event using the default broker
|
||||||
|
func Publish(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Publish(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishSync publishes an event synchronously using the default broker
|
||||||
|
func PublishSync(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.PublishSync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublishAsync publishes an event asynchronously using the default broker
|
||||||
|
func PublishAsync(ctx context.Context, event *Event) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.PublishAsync(ctx, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe subscribes to events using the default broker
|
||||||
|
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return "", fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Subscribe(pattern, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe unsubscribes from events using the default broker
|
||||||
|
func Unsubscribe(id SubscriptionID) error {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Unsubscribe(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics from the default broker
|
||||||
|
func Stats(ctx context.Context) (*BrokerStats, error) {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return nil, fmt.Errorf("event broker not initialized")
|
||||||
|
}
|
||||||
|
return broker.Stats(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterShutdown registers the broker's shutdown with the server shutdown callbacks
|
||||||
|
func RegisterShutdown(broker Broker) {
|
||||||
|
server.RegisterShutdownCallback(func(ctx context.Context) error {
|
||||||
|
logger.Info("Shutting down event broker...")
|
||||||
|
return broker.Stop(ctx)
|
||||||
|
})
|
||||||
|
}
|
||||||
266
pkg/eventbroker/example_usage.go
Normal file
266
pkg/eventbroker/example_usage.go
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
// nolint
|
||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Example demonstrates basic usage of the event broker
|
||||||
|
func Example() {
|
||||||
|
// 1. Create a memory provider
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "example-instance",
|
||||||
|
MaxEvents: 1000,
|
||||||
|
CleanupInterval: 5 * time.Minute,
|
||||||
|
MaxAge: 1 * time.Hour,
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Create a broker
|
||||||
|
broker, err := NewBroker(Options{
|
||||||
|
Provider: provider,
|
||||||
|
Mode: ProcessingModeAsync,
|
||||||
|
WorkerCount: 5,
|
||||||
|
BufferSize: 100,
|
||||||
|
RetryPolicy: DefaultRetryPolicy(),
|
||||||
|
InstanceID: "example-instance",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to create broker: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Start the broker
|
||||||
|
if err := broker.Start(context.Background()); err != nil {
|
||||||
|
logger.Error("Failed to start broker: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := broker.Stop(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to stop broker: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 4. Subscribe to events
|
||||||
|
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// 5. Publish events
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Database event
|
||||||
|
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
|
||||||
|
dbEvent.InstanceID = "example-instance"
|
||||||
|
dbEvent.UserID = 123
|
||||||
|
dbEvent.SessionID = "session-456"
|
||||||
|
dbEvent.Schema = "public"
|
||||||
|
dbEvent.Entity = "users"
|
||||||
|
dbEvent.Operation = "create"
|
||||||
|
dbEvent.SetPayload(map[string]interface{}{
|
||||||
|
"id": 123,
|
||||||
|
"name": "John Doe",
|
||||||
|
"email": "john@example.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
|
||||||
|
logger.Error("Failed to publish event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebSocket event
|
||||||
|
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
|
||||||
|
wsEvent.InstanceID = "example-instance"
|
||||||
|
wsEvent.UserID = 123
|
||||||
|
wsEvent.SessionID = "session-456"
|
||||||
|
wsEvent.SetPayload(map[string]interface{}{
|
||||||
|
"room": "general",
|
||||||
|
"message": "Hello, World!",
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
|
||||||
|
logger.Error("Failed to publish event: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Get statistics
|
||||||
|
time.Sleep(1 * time.Second) // Wait for processing
|
||||||
|
stats, _ := broker.Stats(ctx)
|
||||||
|
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleWithHooks demonstrates integration with the hook system.
func ExampleWithHooks() {
	// Typically called from main.go (or other initialization code) once the
	// restheadspec.Handler has been constructed.

	// Pseudo-code (an actual implementation would use a real handler):
	/*
		broker := eventbroker.GetDefaultBroker()
		hookRegistry := handler.Hooks()

		// Register CRUD hooks
		config := eventbroker.DefaultCRUDHookConfig()
		config.EnableRead = false // Disable read events for performance

		if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
			logger.Error("Failed to register CRUD hooks: %v", err)
		}

		// From here on, every CRUD operation automatically publishes an event.
	*/
}
|
||||||
|
|
||||||
|
// ExampleSubscriptionPatterns demonstrates different subscription patterns
|
||||||
|
func ExampleSubscriptionPatterns() {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 1: Subscribe to all events from a specific entity
|
||||||
|
broker.Subscribe("public.users.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("User event: %s\n", event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 2: Subscribe to a specific operation across all entities
|
||||||
|
broker.Subscribe("*.*.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 3: Subscribe to all events in a schema
|
||||||
|
broker.Subscribe("public.*.*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
// Pattern 4: Subscribe to everything (use with caution)
|
||||||
|
broker.Subscribe("*", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
fmt.Printf("Any event: %s\n", event.Type)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleErrorHandling demonstrates error handling in event handlers
|
||||||
|
func ExampleErrorHandling() {
|
||||||
|
broker := GetDefaultBroker()
|
||||||
|
if broker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler that may fail
|
||||||
|
broker.Subscribe("public.users.create", EventHandlerFunc(
|
||||||
|
func(ctx context.Context, event *Event) error {
|
||||||
|
// Simulate processing
|
||||||
|
var user struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := event.GetPayload(&user); err != nil {
|
||||||
|
return fmt.Errorf("invalid payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
if user.Email == "" {
|
||||||
|
return fmt.Errorf("email is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process (e.g., send email)
|
||||||
|
logger.Info("Sending welcome email to %s", user.Email)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleConfiguration demonstrates initializing the broker from configuration.
func ExampleConfiguration() {
	// This would normally live in main.go.

	// Pseudo-code:
	/*
		// Load configuration
		cfgMgr := config.NewManager()
		if err := cfgMgr.Load(); err != nil {
			logger.Fatal("Failed to load config: %v", err)
		}

		cfg, err := cfgMgr.GetConfig()
		if err != nil {
			logger.Fatal("Failed to get config: %v", err)
		}

		// Initialize event broker
		if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
			logger.Fatal("Failed to initialize event broker: %v", err)
		}

		// Use the default broker
		eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
			func(ctx context.Context, event *eventbroker.Event) error {
				logger.Info("Created: %s.%s", event.Schema, event.Entity)
				return nil
			},
		))
	*/
}
|
||||||
|
|
||||||
|
// ExampleYAMLConfiguration shows example YAML configuration.
// NOTE(review): the indentation inside this raw string was reconstructed from
// a whitespace-flattened source — verify nesting levels against the config
// loader (config.EventBrokerConfig) before relying on it.
const ExampleYAMLConfiguration = `
event_broker:
  enabled: true
  provider: memory # memory, redis, nats, database
  mode: async # sync, async
  worker_count: 10
  buffer_size: 1000
  instance_id: "${HOSTNAME}"

  # Memory provider is default, no additional config needed

  # Redis provider (when provider: redis)
  redis:
    stream_name: "resolvespec:events"
    consumer_group: "resolvespec-workers"
    host: "localhost"
    port: 6379

  # NATS provider (when provider: nats)
  nats:
    url: "nats://localhost:4222"
    stream_name: "RESOLVESPEC_EVENTS"

  # Database provider (when provider: database)
  database:
    table_name: "events"
    channel: "resolvespec_events"

  # Retry policy
  retry_policy:
    max_retries: 3
    initial_delay: 1s
    max_delay: 30s
    backoff_factor: 2.0
`
|
||||||
56
pkg/eventbroker/factory.go
Normal file
56
pkg/eventbroker/factory.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewProviderFromConfig creates a provider based on configuration
|
||||||
|
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
|
||||||
|
switch cfg.Provider {
|
||||||
|
case "memory":
|
||||||
|
cleanupInterval := 5 * time.Minute
|
||||||
|
if cfg.Database.PollInterval > 0 {
|
||||||
|
cleanupInterval = cfg.Database.PollInterval
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: getInstanceID(cfg.InstanceID),
|
||||||
|
MaxEvents: 10000,
|
||||||
|
CleanupInterval: cleanupInterval,
|
||||||
|
}), nil
|
||||||
|
|
||||||
|
case "redis":
|
||||||
|
// Redis provider will be implemented in Phase 8
|
||||||
|
return nil, fmt.Errorf("redis provider not yet implemented")
|
||||||
|
|
||||||
|
case "nats":
|
||||||
|
// NATS provider will be implemented in Phase 9
|
||||||
|
return nil, fmt.Errorf("nats provider not yet implemented")
|
||||||
|
|
||||||
|
case "database":
|
||||||
|
// Database provider will be implemented in Phase 7
|
||||||
|
return nil, fmt.Errorf("database provider not yet implemented")
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInstanceID returns the instance ID to use: the configured value when
// non-empty, otherwise the machine hostname, otherwise a fixed fallback.
func getInstanceID(configID string) string {
	if configID != "" {
		return configID
	}

	hostname, err := os.Hostname()
	if err != nil {
		// Hostname unavailable — fall back to a default.
		return "resolvespec-instance"
	}
	return hostname
}
|
||||||
17
pkg/eventbroker/handler.go
Normal file
17
pkg/eventbroker/handler.go
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
// EventHandler processes a single event. Returning a non-nil error marks
// the delivery as failed (retry semantics are decided by the broker).
type EventHandler interface {
	Handle(ctx context.Context, event *Event) error
}

// EventHandlerFunc is a function adapter for EventHandler.
// It allows plain functions to be used wherever an EventHandler is expected,
// in the same style as http.HandlerFunc.
type EventHandlerFunc func(ctx context.Context, event *Event) error

// Handle implements EventHandler by invoking the function itself.
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
	return f(ctx, event)
}
|
||||||
137
pkg/eventbroker/hooks.go
Normal file
137
pkg/eventbroker/hooks.go
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CRUDHookConfig configures which CRUD operations should trigger events.
type CRUDHookConfig struct {
	EnableCreate bool // publish an event after each create
	EnableRead   bool // publish an event after each read (typically high volume)
	EnableUpdate bool // publish an event after each update
	EnableDelete bool // publish an event after each delete
}
|
||||||
|
|
||||||
|
// DefaultCRUDHookConfig returns default configuration (all enabled)
|
||||||
|
func DefaultCRUDHookConfig() *CRUDHookConfig {
|
||||||
|
return &CRUDHookConfig{
|
||||||
|
EnableCreate: true,
|
||||||
|
EnableRead: false, // Typically disabled for performance
|
||||||
|
EnableUpdate: true,
|
||||||
|
EnableDelete: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterCRUDHooks registers event hooks for CRUD operations.
// This integrates with the restheadspec.HookRegistry so that after-CRUD
// hooks automatically publish database events on the given broker.
// A nil config falls back to DefaultCRUDHookConfig. Event publishing is
// asynchronous and never fails the originating CRUD operation.
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
	if broker == nil {
		return fmt.Errorf("broker cannot be nil")
	}
	if hookRegistry == nil {
		return fmt.Errorf("hookRegistry cannot be nil")
	}
	if config == nil {
		config = DefaultCRUDHookConfig()
	}

	// Create hook handler factory: one HookFunc per operation name, each
	// building and publishing an event from the hook context.
	createHookHandler := func(operation string) restheadspec.HookFunc {
		return func(hookCtx *restheadspec.HookContext) error {
			// Get user context from Go context; fall back to an empty user
			// context so anonymous operations still publish events.
			userCtx, ok := security.GetUserContext(hookCtx.Context)
			if !ok || userCtx == nil {
				logger.Debug("No user context found in hook")
				userCtx = &security.UserContext{} // Empty user context
			}

			// Create event describing this database operation.
			event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
			event.InstanceID = broker.InstanceID()
			event.UserID = userCtx.UserID
			event.SessionID = userCtx.SessionID
			event.Schema = hookCtx.Schema
			event.Entity = hookCtx.Entity
			event.Operation = operation

			// Set payload based on operation: create/read carry the result,
			// update carries id+data, delete carries only the id.
			var payload interface{}
			switch operation {
			case "create":
				payload = hookCtx.Result
			case "read":
				payload = hookCtx.Result
			case "update":
				payload = map[string]interface{}{
					"id":   hookCtx.ID,
					"data": hookCtx.Data,
				}
			case "delete":
				payload = map[string]interface{}{
					"id": hookCtx.ID,
				}
			}

			if payload != nil {
				if err := event.SetPayload(payload); err != nil {
					// Serialization failed — publish a sentinel payload
					// rather than dropping the event entirely.
					logger.Error("Failed to set event payload: %v", err)
					payload = map[string]interface{}{"error": "failed to serialize payload"}
					event.Payload, _ = json.Marshal(payload)
				}
			}

			// Add user/session metadata when available.
			if userCtx.UserName != "" {
				event.Metadata["user_name"] = userCtx.UserName
			}
			if userCtx.Email != "" {
				event.Metadata["user_email"] = userCtx.Email
			}
			if len(userCtx.Roles) > 0 {
				event.Metadata["user_roles"] = userCtx.Roles
			}
			event.Metadata["table_name"] = hookCtx.TableName

			// Publish asynchronously so event delivery never blocks or
			// fails the CRUD operation itself.
			if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
				logger.Error("Failed to publish %s event for %s.%s: %v",
					operation, hookCtx.Schema, hookCtx.Entity, err)
				// Don't fail the CRUD operation if event publishing fails
				return nil
			}

			logger.Debug("Published %s event for %s.%s (ID: %s)",
				operation, hookCtx.Schema, hookCtx.Entity, event.ID)
			return nil
		}
	}

	// Register hooks based on configuration.
	if config.EnableCreate {
		hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
		logger.Info("Registered event hook for CREATE operations")
	}

	if config.EnableRead {
		hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
		logger.Info("Registered event hook for READ operations")
	}

	if config.EnableUpdate {
		hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
		logger.Info("Registered event hook for UPDATE operations")
	}

	if config.EnableDelete {
		hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
		logger.Info("Registered event hook for DELETE operations")
	}

	return nil
}
|
||||||
28
pkg/eventbroker/metrics.go
Normal file
28
pkg/eventbroker/metrics.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
// recordEventPublished records an event publication metric
|
||||||
|
func recordEventPublished(event *Event) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.RecordEventPublished(string(event.Source), event.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordEventProcessed records an event processing metric
|
||||||
|
func recordEventProcessed(event *Event, duration time.Duration) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateQueueSize updates the event queue size metric
|
||||||
|
func updateQueueSize(size int64) {
|
||||||
|
if mp := metrics.GetProvider(); mp != nil {
|
||||||
|
mp.UpdateEventQueueSize(size)
|
||||||
|
}
|
||||||
|
}
|
||||||
70
pkg/eventbroker/provider.go
Normal file
70
pkg/eventbroker/provider.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider defines the storage backend interface for events.
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider.
type Provider interface {
	// Store stores an event.
	Store(ctx context.Context, event *Event) error

	// Get retrieves an event by ID.
	Get(ctx context.Context, id string) (*Event, error)

	// List lists events with optional filters.
	List(ctx context.Context, filter *EventFilter) ([]*Event, error)

	// UpdateStatus updates the status of an event; errorMsg, when non-empty,
	// records why the event failed.
	UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error

	// Delete deletes an event by ID.
	Delete(ctx context.Context, id string) error

	// Stream returns a channel of events for real-time consumption.
	// Used for cross-instance pub/sub.
	// The channel is closed when the context is canceled or an error occurs.
	Stream(ctx context.Context, pattern string) (<-chan *Event, error)

	// Publish publishes an event to all subscribers (for distributed providers).
	// For the in-memory provider, this is Store plus in-process fan-out.
	// For Redis/NATS/Database, this triggers cross-instance delivery.
	Publish(ctx context.Context, event *Event) error

	// Close closes the provider and releases resources.
	Close() error

	// Stats returns provider statistics.
	Stats(ctx context.Context) (*ProviderStats, error)
}
|
||||||
|
|
||||||
|
// EventFilter defines filter criteria for listing events.
// Nil pointer fields and empty strings are presumably treated as "no filter
// on this field" by provider implementations (see matchesFilter) — verify.
type EventFilter struct {
	Source     *EventSource // filter by event source when non-nil
	Status     *EventStatus // filter by event status when non-nil
	UserID     *int         // filter by originating user when non-nil
	Schema     string       // filter by schema when non-empty
	Entity     string       // filter by entity when non-empty
	Operation  string       // filter by operation when non-empty
	InstanceID string       // filter by originating instance when non-empty
	StartTime  *time.Time   // lower time bound when non-nil
	EndTime    *time.Time   // upper time bound when non-nil
	Limit      int          // max results to return; 0 means unlimited
	Offset     int          // number of matching results to skip
}
|
||||||
|
|
||||||
|
// ProviderStats contains statistics about the provider.
type ProviderStats struct {
	ProviderType      string                 `json:"provider_type"`               // e.g. "memory", "redis"
	TotalEvents       int64                  `json:"total_events"`                // events currently stored
	PendingEvents     int64                  `json:"pending_events"`              // events awaiting processing
	ProcessingEvents  int64                  `json:"processing_events"`           // events being processed
	CompletedEvents   int64                  `json:"completed_events"`            // successfully processed events
	FailedEvents      int64                  `json:"failed_events"`               // events that failed processing
	EventsPublished   int64                  `json:"events_published"`            // total publish calls
	EventsConsumed    int64                  `json:"events_consumed"`             // total deliveries to subscribers
	ActiveSubscribers int                    `json:"active_subscribers"`          // currently registered stream subscribers
	ProviderSpecific  map[string]interface{} `json:"provider_specific,omitempty"` // backend-specific extras
}
|
||||||
446
pkg/eventbroker/provider_memory.go
Normal file
446
pkg/eventbroker/provider_memory.go
Normal file
@ -0,0 +1,446 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemoryProvider implements Provider interface using in-memory storage.
// Features:
//   - Thread-safe event storage with RW mutex
//   - LRU eviction when max events reached
//   - In-process pub/sub (not cross-instance)
//   - Automatic cleanup of old completed events
type MemoryProvider struct {
	mu              sync.RWMutex             // guards events, eventOrder and subscribers
	events          map[string]*Event        // stored events keyed by event ID
	eventOrder      []string                 // event IDs in insertion order, for LRU tracking
	subscribers     map[string][]chan *Event // pattern -> subscriber channels (in-process only)
	instanceID      string                   // identifier of this process instance
	maxEvents       int                      // capacity before LRU eviction kicks in
	cleanupInterval time.Duration            // how often the cleanup loop runs
	maxAge          time.Duration            // age threshold used by cleanup

	// Statistics
	stats MemoryProviderStats

	// Lifecycle
	stopCleanup chan struct{}  // closed by Close to stop background work
	wg          sync.WaitGroup // tracks cleanupLoop and Stream goroutines
	isRunning   atomic.Bool    // false once Close has run
}
|
||||||
|
|
||||||
|
// MemoryProviderStats contains statistics for the memory provider.
// All fields are atomic so they can be updated without holding the
// provider's mutex.
type MemoryProviderStats struct {
	TotalEvents       atomic.Int64 // events currently stored
	PendingEvents     atomic.Int64 // events with pending status
	ProcessingEvents  atomic.Int64 // events with processing status
	CompletedEvents   atomic.Int64 // events with completed status
	FailedEvents      atomic.Int64 // events with failed status
	EventsPublished   atomic.Int64 // total Publish calls
	EventsConsumed    atomic.Int64 // total deliveries to subscriber channels
	ActiveSubscribers atomic.Int32 // currently active Stream subscribers
	Evictions         atomic.Int64 // events evicted due to capacity
}
|
||||||
|
|
||||||
|
// MemoryProviderOptions configures the memory provider.
// Zero values are replaced with defaults by NewMemoryProvider
// (10000 events, 5 minute cleanup interval, 24 hour max age).
type MemoryProviderOptions struct {
	InstanceID      string        // identifier for this process instance
	MaxEvents       int           // max stored events before LRU eviction
	CleanupInterval time.Duration // how often old events are purged
	MaxAge          time.Duration // age after which events are purged
}
|
||||||
|
|
||||||
|
// NewMemoryProvider creates a new in-memory event provider
|
||||||
|
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
|
||||||
|
if opts.MaxEvents == 0 {
|
||||||
|
opts.MaxEvents = 10000 // Default
|
||||||
|
}
|
||||||
|
if opts.CleanupInterval == 0 {
|
||||||
|
opts.CleanupInterval = 5 * time.Minute // Default
|
||||||
|
}
|
||||||
|
if opts.MaxAge == 0 {
|
||||||
|
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
|
||||||
|
}
|
||||||
|
|
||||||
|
mp := &MemoryProvider{
|
||||||
|
events: make(map[string]*Event),
|
||||||
|
eventOrder: make([]string, 0),
|
||||||
|
subscribers: make(map[string][]chan *Event),
|
||||||
|
instanceID: opts.InstanceID,
|
||||||
|
maxEvents: opts.MaxEvents,
|
||||||
|
cleanupInterval: opts.CleanupInterval,
|
||||||
|
maxAge: opts.MaxAge,
|
||||||
|
stopCleanup: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start cleanup goroutine
|
||||||
|
mp.wg.Add(1)
|
||||||
|
go mp.cleanupLoop()
|
||||||
|
|
||||||
|
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
|
||||||
|
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
|
||||||
|
|
||||||
|
return mp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store stores an event.
// The event is cloned before insertion so later mutations by the caller do
// not affect the stored copy. When the provider is at capacity the oldest
// stored event (by insertion order) is evicted first.
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
	mp.mu.Lock()
	defer mp.mu.Unlock()

	// Check if we need to evict oldest events
	if len(mp.events) >= mp.maxEvents {
		mp.evictOldestLocked()
	}

	// Store event
	mp.events[event.ID] = event.Clone()
	mp.eventOrder = append(mp.eventOrder, event.ID)

	// Update statistics
	mp.stats.TotalEvents.Add(1)
	mp.updateStatusCountsLocked(event.Status, 1)

	return nil
}
|
||||||
|
|
||||||
|
// Get retrieves an event by ID
|
||||||
|
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return nil, fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return event.Clone(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List lists events with optional filters
|
||||||
|
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
var results []*Event
|
||||||
|
|
||||||
|
for _, event := range mp.events {
|
||||||
|
if mp.matchesFilter(event, filter) {
|
||||||
|
results = append(results, event.Clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply limit and offset
|
||||||
|
if filter != nil {
|
||||||
|
if filter.Offset > 0 && filter.Offset < len(results) {
|
||||||
|
results = results[filter.Offset:]
|
||||||
|
}
|
||||||
|
if filter.Limit > 0 && filter.Limit < len(results) {
|
||||||
|
results = results[:filter.Limit]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateStatus updates the status of an event
|
||||||
|
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update status counts
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
mp.updateStatusCountsLocked(status, 1)
|
||||||
|
|
||||||
|
// Update event
|
||||||
|
event.Status = status
|
||||||
|
if errorMsg != "" {
|
||||||
|
event.Error = errorMsg
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes an event by ID
|
||||||
|
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
event, exists := mp.events[id]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("event not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update counts
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
|
||||||
|
// Delete event
|
||||||
|
delete(mp.events, id)
|
||||||
|
|
||||||
|
// Remove from order tracking
|
||||||
|
for i, eid := range mp.eventOrder {
|
||||||
|
if eid == id {
|
||||||
|
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream returns a channel of events for real-time consumption
|
||||||
|
// Note: This is in-process only, not cross-instance
|
||||||
|
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Create buffered channel for events
|
||||||
|
ch := make(chan *Event, 100)
|
||||||
|
|
||||||
|
// Store subscriber
|
||||||
|
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
|
||||||
|
mp.stats.ActiveSubscribers.Add(1)
|
||||||
|
|
||||||
|
// Goroutine to clean up on context cancellation
|
||||||
|
mp.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer mp.wg.Done()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
// Remove subscriber
|
||||||
|
subs := mp.subscribers[pattern]
|
||||||
|
for i, subCh := range subs {
|
||||||
|
if subCh == ch {
|
||||||
|
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.stats.ActiveSubscribers.Add(-1)
|
||||||
|
close(ch)
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Debug("Stream created for pattern: %s", pattern)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish publishes an event to all subscribers
|
||||||
|
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
|
||||||
|
// Store the event first
|
||||||
|
if err := mp.Store(ctx, event); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
mp.stats.EventsPublished.Add(1)
|
||||||
|
|
||||||
|
// Notify subscribers
|
||||||
|
mp.mu.RLock()
|
||||||
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
|
for pattern, channels := range mp.subscribers {
|
||||||
|
if matchPattern(pattern, event.Type) {
|
||||||
|
for _, ch := range channels {
|
||||||
|
select {
|
||||||
|
case ch <- event.Clone():
|
||||||
|
mp.stats.EventsConsumed.Add(1)
|
||||||
|
default:
|
||||||
|
// Channel full, skip
|
||||||
|
logger.Warn("Subscriber channel full for pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the provider and releases resources: stops the cleanup loop,
// waits for background goroutines, then closes every subscriber channel.
// Subsequent calls are no-ops.
func (mp *MemoryProvider) Close() error {
	// Idempotence guard. NOTE(review): Load-then-Store is not atomic, so two
	// concurrent first calls could both proceed — consider CompareAndSwap.
	if !mp.isRunning.Load() {
		return nil
	}

	mp.isRunning.Store(false)

	// Stop cleanup loop
	close(mp.stopCleanup)

	// Wait for goroutines
	// NOTE(review): the wait group also tracks Stream goroutines, which only
	// exit when their caller's context is canceled — verify shutdown
	// ordering so Close cannot block indefinitely.
	mp.wg.Wait()

	// Close all subscriber channels
	mp.mu.Lock()
	for _, channels := range mp.subscribers {
		for _, ch := range channels {
			close(ch)
		}
	}
	mp.subscribers = make(map[string][]chan *Event)
	mp.mu.Unlock()

	logger.Info("Memory provider closed")
	return nil
}
|
||||||
|
|
||||||
|
// Stats returns provider statistics
|
||||||
|
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
|
||||||
|
return &ProviderStats{
|
||||||
|
ProviderType: "memory",
|
||||||
|
TotalEvents: mp.stats.TotalEvents.Load(),
|
||||||
|
PendingEvents: mp.stats.PendingEvents.Load(),
|
||||||
|
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
|
||||||
|
CompletedEvents: mp.stats.CompletedEvents.Load(),
|
||||||
|
FailedEvents: mp.stats.FailedEvents.Load(),
|
||||||
|
EventsPublished: mp.stats.EventsPublished.Load(),
|
||||||
|
EventsConsumed: mp.stats.EventsConsumed.Load(),
|
||||||
|
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
|
||||||
|
ProviderSpecific: map[string]interface{}{
|
||||||
|
"max_events": mp.maxEvents,
|
||||||
|
"cleanup_interval": mp.cleanupInterval.String(),
|
||||||
|
"max_age": mp.maxAge.String(),
|
||||||
|
"evictions": mp.stats.Evictions.Load(),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupLoop periodically cleans up old completed events
|
||||||
|
func (mp *MemoryProvider) cleanupLoop() {
|
||||||
|
defer mp.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(mp.cleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
mp.cleanup()
|
||||||
|
case <-mp.stopCleanup:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup removes old completed/failed events
|
||||||
|
func (mp *MemoryProvider) cleanup() {
|
||||||
|
mp.mu.Lock()
|
||||||
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
|
cutoff := time.Now().Add(-mp.maxAge)
|
||||||
|
removed := 0
|
||||||
|
|
||||||
|
for id, event := range mp.events {
|
||||||
|
// Only clean up completed or failed events that are old
|
||||||
|
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
|
||||||
|
event.CreatedAt.Before(cutoff) {
|
||||||
|
|
||||||
|
delete(mp.events, id)
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
|
||||||
|
// Remove from order tracking
|
||||||
|
for i, eid := range mp.eventOrder {
|
||||||
|
if eid == id {
|
||||||
|
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed > 0 {
|
||||||
|
logger.Debug("Cleanup removed %d old events", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestLocked evicts the oldest event (LRU)
|
||||||
|
// Caller must hold write lock
|
||||||
|
func (mp *MemoryProvider) evictOldestLocked() {
|
||||||
|
if len(mp.eventOrder) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get oldest event ID
|
||||||
|
oldestID := mp.eventOrder[0]
|
||||||
|
mp.eventOrder = mp.eventOrder[1:]
|
||||||
|
|
||||||
|
// Remove event
|
||||||
|
if event, exists := mp.events[oldestID]; exists {
|
||||||
|
delete(mp.events, oldestID)
|
||||||
|
mp.stats.TotalEvents.Add(-1)
|
||||||
|
mp.updateStatusCountsLocked(event.Status, -1)
|
||||||
|
mp.stats.Evictions.Add(1)
|
||||||
|
|
||||||
|
logger.Debug("Evicted oldest event: %s", oldestID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesFilter checks if an event matches the filter criteria
|
||||||
|
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
|
||||||
|
if filter == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if filter.Source != nil && event.Source != *filter.Source {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Status != nil && event.Status != *filter.Status {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.UserID != nil && event.UserID != *filter.UserID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Schema != "" && event.Schema != filter.Schema {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Entity != "" && event.Entity != filter.Entity {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.Operation != "" && event.Operation != filter.Operation {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateStatusCountsLocked adjusts the per-status counter for the given
// status by delta (use +1 when an event enters a status, -1 when it leaves).
// Caller must hold write lock.
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
	switch status {
	case EventStatusPending:
		mp.stats.PendingEvents.Add(delta)
	case EventStatusProcessing:
		mp.stats.ProcessingEvents.Add(delta)
	case EventStatusCompleted:
		mp.stats.CompletedEvents.Add(delta)
	case EventStatusFailed:
		mp.stats.FailedEvents.Add(delta)
	}
	// Statuses outside the four known values are deliberately ignored.
}
|
||||||
419
pkg/eventbroker/provider_memory_test.go
Normal file
419
pkg/eventbroker/provider_memory_test.go
Normal file
@ -0,0 +1,419 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewMemoryProvider(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 100,
|
||||||
|
CleanupInterval: 1 * time.Minute,
|
||||||
|
})
|
||||||
|
|
||||||
|
if provider == nil {
|
||||||
|
t.Fatal("Expected non-nil provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := provider.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.ProviderType != "memory" {
|
||||||
|
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderPublishAndGet(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
event.UserID = 123
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
if err := provider.Publish(context.Background(), event); err != nil {
|
||||||
|
t.Fatalf("Publish failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get event
|
||||||
|
retrieved, err := provider.Get(context.Background(), event.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Get failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if retrieved.ID != event.ID {
|
||||||
|
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
|
||||||
|
}
|
||||||
|
if retrieved.UserID != 123 {
|
||||||
|
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderGetNonExistent(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := provider.Get(context.Background(), "non-existent-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when getting non-existent event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderUpdateStatus(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Update status to processing
|
||||||
|
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieved, _ := provider.Get(context.Background(), event.ID)
|
||||||
|
if retrieved.Status != EventStatusProcessing {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update status to failed with error
|
||||||
|
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("UpdateStatus failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retrieved, _ = provider.Get(context.Background(), event.ID)
|
||||||
|
if retrieved.Status != EventStatusFailed {
|
||||||
|
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
|
||||||
|
}
|
||||||
|
if retrieved.Error != "test error" {
|
||||||
|
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderList(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish multiple events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List all events
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 5 {
|
||||||
|
t.Errorf("Expected 5 events, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderListWithFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events with different types
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
event3 := NewEvent(EventSourceWebSocket, "chat.message")
|
||||||
|
provider.Publish(context.Background(), event3)
|
||||||
|
|
||||||
|
// Filter by source
|
||||||
|
source := EventSourceDatabase
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
Source: &source,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Errorf("Expected 2 events with database source, got %d", len(events))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by status
|
||||||
|
status := EventStatusPending
|
||||||
|
events, err = provider.List(context.Background(), &EventFilter{
|
||||||
|
Status: &status,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 3 {
|
||||||
|
t.Errorf("Expected 3 events with pending status, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderListWithLimit(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish multiple events
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List with limit
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
Limit: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 5 {
|
||||||
|
t.Errorf("Expected 5 events (limited), got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderDelete(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Delete event
|
||||||
|
err := provider.Delete(context.Background(), event.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Delete failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify deleted
|
||||||
|
_, err = provider.Get(context.Background(), event.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when getting deleted event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderLRUEviction(t *testing.T) {
|
||||||
|
// Create provider with small max events
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 3,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish 5 events
|
||||||
|
events := make([]*Event, 5)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
events[i] = NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), events[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
// First 2 events should be evicted
|
||||||
|
_, err := provider.Get(context.Background(), events[0].ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected first event to be evicted")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = provider.Get(context.Background(), events[1].ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected second event to be evicted")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Last 3 events should still exist
|
||||||
|
for i := 2; i < 5; i++ {
|
||||||
|
_, err := provider.Get(context.Background(), events[i].ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected event %d to still exist", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderCleanup(t *testing.T) {
|
||||||
|
// Create provider with short cleanup interval
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
MaxAge: 200 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish and complete an event
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
|
||||||
|
|
||||||
|
// Wait for cleanup to run
|
||||||
|
time.Sleep(400 * time.Millisecond)
|
||||||
|
|
||||||
|
// Event should be cleaned up
|
||||||
|
_, err := provider.Get(context.Background(), event.ID)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected event to be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
provider.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderStats(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
MaxEvents: 100,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, err := provider.Stats(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stats failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stats.ProviderType != "memory" {
|
||||||
|
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
|
||||||
|
}
|
||||||
|
if stats.TotalEvents != 5 {
|
||||||
|
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderClose(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
CleanupInterval: 100 * time.Millisecond,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish event
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
|
||||||
|
// Close provider
|
||||||
|
err := provider.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Close failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup goroutine should be stopped
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderConcurrency(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Concurrent publish
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func() {
|
||||||
|
defer func() { done <- true }()
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all events were stored
|
||||||
|
events, _ := provider.List(context.Background(), &EventFilter{})
|
||||||
|
if len(events) != 10 {
|
||||||
|
t.Errorf("Expected 10 events, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderStream(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Stream is implemented for memory provider (in-process pub/sub)
|
||||||
|
ch, err := provider.Stream(context.Background(), "test.*")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stream failed: %v", err)
|
||||||
|
}
|
||||||
|
if ch == nil {
|
||||||
|
t.Error("Expected non-nil channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events at different times
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
event3 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
provider.Publish(context.Background(), event3)
|
||||||
|
|
||||||
|
// Filter by time range
|
||||||
|
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
StartTime: &startTime,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should get events 2 and 3
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Errorf("Expected 2 events after start time, got %d", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
|
||||||
|
provider := NewMemoryProvider(MemoryProviderOptions{
|
||||||
|
InstanceID: "test-instance",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Publish events with different instance IDs
|
||||||
|
event1 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
event1.InstanceID = "instance-1"
|
||||||
|
provider.Publish(context.Background(), event1)
|
||||||
|
|
||||||
|
event2 := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
event2.InstanceID = "instance-2"
|
||||||
|
provider.Publish(context.Background(), event2)
|
||||||
|
|
||||||
|
// Filter by instance ID
|
||||||
|
events, err := provider.List(context.Background(), &EventFilter{
|
||||||
|
InstanceID: "instance-1",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("List failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
|
||||||
|
}
|
||||||
|
if events[0].InstanceID != "instance-1" {
|
||||||
|
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
|
||||||
|
}
|
||||||
|
}
|
||||||
140
pkg/eventbroker/subscription.go
Normal file
140
pkg/eventbroker/subscription.go
Normal file
@ -0,0 +1,140 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SubscriptionID uniquely identifies a subscription within a
// subscriptionManager. IDs are generated by Subscribe as "sub-<counter>".
type SubscriptionID string
|
||||||
|
|
||||||
|
// subscription represents a single subscription with its handler and pattern.
type subscription struct {
	id      SubscriptionID // unique ID assigned at Subscribe time
	pattern string         // glob-style event-type pattern (see matchPattern)
	handler EventHandler   // invoked for events whose type matches pattern
}
|
||||||
|
|
||||||
|
// subscriptionManager manages event subscriptions and pattern matching.
// All methods take mu, so the manager is safe for concurrent use.
type subscriptionManager struct {
	mu            sync.RWMutex                     // guards subscriptions
	subscriptions map[SubscriptionID]*subscription // active subscriptions keyed by ID
	nextID        atomic.Uint64                    // monotonic counter used to mint IDs
}
|
||||||
|
|
||||||
|
// newSubscriptionManager creates a new subscription manager
|
||||||
|
func newSubscriptionManager() *subscriptionManager {
|
||||||
|
return &subscriptionManager{
|
||||||
|
subscriptions: make(map[SubscriptionID]*subscription),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe adds a new subscription
|
||||||
|
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
|
||||||
|
if pattern == "" {
|
||||||
|
return "", fmt.Errorf("pattern cannot be empty")
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
|
return "", fmt.Errorf("handler cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
|
||||||
|
|
||||||
|
sm.mu.Lock()
|
||||||
|
sm.subscriptions[id] = &subscription{
|
||||||
|
id: id,
|
||||||
|
pattern: pattern,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
sm.mu.Unlock()
|
||||||
|
|
||||||
|
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe removes a subscription
|
||||||
|
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
|
if _, exists := sm.subscriptions[id]; !exists {
|
||||||
|
return fmt.Errorf("subscription not found: %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sm.subscriptions, id)
|
||||||
|
logger.Info("Unsubscribed: %s", id)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetMatching returns all handlers that match the event type
|
||||||
|
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
|
||||||
|
var handlers []EventHandler
|
||||||
|
for _, sub := range sm.subscriptions {
|
||||||
|
if matchPattern(sub.pattern, eventType) {
|
||||||
|
handlers = append(handlers, sub.handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return handlers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of active subscriptions
|
||||||
|
func (sm *subscriptionManager) Count() int {
|
||||||
|
sm.mu.RLock()
|
||||||
|
defer sm.mu.RUnlock()
|
||||||
|
return len(sm.subscriptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all subscriptions
|
||||||
|
func (sm *subscriptionManager) Clear() {
|
||||||
|
sm.mu.Lock()
|
||||||
|
defer sm.mu.Unlock()
|
||||||
|
sm.subscriptions = make(map[SubscriptionID]*subscription)
|
||||||
|
logger.Info("Cleared all subscriptions")
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchPattern reports whether a dot-separated event type matches a
// glob-style pattern.
//
// Rules:
//   - the pattern "*" on its own matches every event type
//   - within a pattern, "*" matches exactly one dot-separated segment
//   - otherwise the pattern and event type must have the same number of
//     segments and agree segment-by-segment
//
// Event type format: schema.entity.operation (e.g. "public.users.create").
func matchPattern(pattern, eventType string) bool {
	// Fast paths: global wildcard and literal equality.
	if pattern == "*" || pattern == eventType {
		return true
	}

	patternSegs := strings.Split(pattern, ".")
	eventSegs := strings.Split(eventType, ".")

	// Segment counts must agree; "*" never spans multiple segments.
	if len(patternSegs) != len(eventSegs) {
		return false
	}

	for i, seg := range patternSegs {
		if seg != "*" && seg != eventSegs[i] {
			return false
		}
	}
	return true
}
|
||||||
270
pkg/eventbroker/subscription_test.go
Normal file
270
pkg/eventbroker/subscription_test.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMatchPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
pattern string
|
||||||
|
eventType string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// Exact matches
|
||||||
|
{"public.users.create", "public.users.create", true},
|
||||||
|
{"public.users.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Wildcard matches
|
||||||
|
{"*", "public.users.create", true},
|
||||||
|
{"*", "anything", true},
|
||||||
|
{"public.*", "public.users", true},
|
||||||
|
{"public.*", "public.users.create", false}, // Different number of parts
|
||||||
|
{"public.*", "admin.users", false},
|
||||||
|
{"*.users.create", "public.users.create", true},
|
||||||
|
{"*.users.create", "admin.users.create", true},
|
||||||
|
{"*.users.create", "public.roles.create", false},
|
||||||
|
{"public.*.create", "public.users.create", true},
|
||||||
|
{"public.*.create", "public.roles.create", true},
|
||||||
|
{"public.*.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Multiple wildcards
|
||||||
|
{"*.*", "public.users", true},
|
||||||
|
{"*.*", "public.users.create", false}, // Different number of parts
|
||||||
|
{"*.*.create", "public.users.create", true},
|
||||||
|
{"*.*.create", "admin.roles.create", true},
|
||||||
|
{"*.*.create", "public.users.update", false},
|
||||||
|
|
||||||
|
// Edge cases
|
||||||
|
{"", "", true},
|
||||||
|
{"", "something", false},
|
||||||
|
{"something", "", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
|
||||||
|
result := matchPattern(tt.pattern, tt.eventType)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
|
||||||
|
tt.pattern, tt.eventType, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManager(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// Create test handler
|
||||||
|
called := false
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test Subscribe
|
||||||
|
id, err := manager.Subscribe("public.users.*", handler)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Subscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
t.Fatal("Expected non-empty subscription ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GetMatching
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 1 {
|
||||||
|
t.Fatalf("Expected 1 handler, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test handler execution
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
if err := handlers[0].Handle(context.Background(), event); err != nil {
|
||||||
|
t.Fatalf("Handler execution failed: %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Count
|
||||||
|
if manager.Count() != 1 {
|
||||||
|
t.Errorf("Expected count 1, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Unsubscribe
|
||||||
|
if err := manager.Unsubscribe(id); err != nil {
|
||||||
|
t.Fatalf("Unsubscribe failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify unsubscribed
|
||||||
|
handlers = manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 0 {
|
||||||
|
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
called1 := false
|
||||||
|
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called1 = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
called2 := false
|
||||||
|
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called2 = true
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe multiple handlers
|
||||||
|
id1, _ := manager.Subscribe("public.users.*", handler1)
|
||||||
|
id2, _ := manager.Subscribe("*.users.*", handler2)
|
||||||
|
|
||||||
|
// Both should match
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 2 {
|
||||||
|
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all handlers
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
for _, h := range handlers {
|
||||||
|
h.Handle(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !called1 || !called2 {
|
||||||
|
t.Error("Expected both handlers to be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe one
|
||||||
|
manager.Unsubscribe(id1)
|
||||||
|
handlers = manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 1 {
|
||||||
|
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unsubscribe remaining
|
||||||
|
manager.Unsubscribe(id2)
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerConcurrency(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe and unsubscribe concurrently
|
||||||
|
done := make(chan bool, 10)
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
go func() {
|
||||||
|
defer func() { done <- true }()
|
||||||
|
id, _ := manager.Subscribe("test.*", handler)
|
||||||
|
manager.GetMatching("test.event")
|
||||||
|
manager.Unsubscribe(id)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all goroutines
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have no subscriptions left
|
||||||
|
if manager.Count() != 0 {
|
||||||
|
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// Try to unsubscribe a non-existent ID
|
||||||
|
err := manager.Unsubscribe("non-existent-id")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when unsubscribing non-existent ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionIDGeneration(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Subscribe multiple times and ensure unique IDs
|
||||||
|
ids := make(map[SubscriptionID]bool)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
id, _ := manager.Subscribe("test.*", handler)
|
||||||
|
if ids[id] {
|
||||||
|
t.Fatalf("Duplicate subscription ID: %s", id)
|
||||||
|
}
|
||||||
|
ids[id] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEventHandlerFunc(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
var receivedEvent *Event
|
||||||
|
|
||||||
|
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
called = true
|
||||||
|
receivedEvent = event
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
event := NewEvent(EventSourceDatabase, "test.event")
|
||||||
|
err := handler.Handle(context.Background(), event)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if !called {
|
||||||
|
t.Error("Expected handler to be called")
|
||||||
|
}
|
||||||
|
if receivedEvent != event {
|
||||||
|
t.Error("Expected to receive the same event")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSubscriptionManagerPatternPriority(t *testing.T) {
|
||||||
|
manager := newSubscriptionManager()
|
||||||
|
|
||||||
|
// More specific patterns should still match
|
||||||
|
specificCalled := false
|
||||||
|
genericCalled := false
|
||||||
|
|
||||||
|
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
specificCalled = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
|
||||||
|
genericCalled = true
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
|
||||||
|
handlers := manager.GetMatching("public.users.create")
|
||||||
|
if len(handlers) != 2 {
|
||||||
|
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute all handlers
|
||||||
|
event := NewEvent(EventSourceDatabase, "public.users.create")
|
||||||
|
for _, h := range handlers {
|
||||||
|
h.Handle(context.Background(), event)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !specificCalled || !genericCalled {
|
||||||
|
t.Error("Expected both specific and generic handlers to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
141
pkg/eventbroker/worker_pool.go
Normal file
141
pkg/eventbroker/worker_pool.go
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
package eventbroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// workerPool manages a pool of workers for async event processing
|
||||||
|
type workerPool struct {
|
||||||
|
workerCount int
|
||||||
|
bufferSize int
|
||||||
|
eventQueue chan *Event
|
||||||
|
processor func(context.Context, *Event) error
|
||||||
|
|
||||||
|
activeWorkers atomic.Int32
|
||||||
|
isRunning atomic.Bool
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWorkerPool creates a new worker pool
|
||||||
|
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
|
||||||
|
return &workerPool{
|
||||||
|
workerCount: workerCount,
|
||||||
|
bufferSize: bufferSize,
|
||||||
|
eventQueue: make(chan *Event, bufferSize),
|
||||||
|
processor: processor,
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the worker pool
|
||||||
|
func (wp *workerPool) Start() {
|
||||||
|
if wp.isRunning.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.isRunning.Store(true)
|
||||||
|
|
||||||
|
// Start workers
|
||||||
|
for i := 0; i < wp.workerCount; i++ {
|
||||||
|
wp.wg.Add(1)
|
||||||
|
go wp.worker(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Worker pool started with %d workers", wp.workerCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the worker pool gracefully
|
||||||
|
func (wp *workerPool) Stop(ctx context.Context) error {
|
||||||
|
if !wp.isRunning.Load() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.isRunning.Store(false)
|
||||||
|
|
||||||
|
// Close event queue to signal workers
|
||||||
|
close(wp.eventQueue)
|
||||||
|
|
||||||
|
// Wait for workers to finish with context timeout
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wp.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
logger.Info("Worker pool stopped gracefully")
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
logger.Warn("Worker pool stop timed out, some events may be lost")
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Submit submits an event to the queue
|
||||||
|
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
|
||||||
|
if !wp.isRunning.Load() {
|
||||||
|
return ErrWorkerPoolStopped
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case wp.eventQueue <- event:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
return ErrQueueFull
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is a worker goroutine that processes events from the queue
|
||||||
|
func (wp *workerPool) worker(id int) {
|
||||||
|
defer wp.wg.Done()
|
||||||
|
|
||||||
|
logger.Debug("Worker %d started", id)
|
||||||
|
|
||||||
|
for event := range wp.eventQueue {
|
||||||
|
wp.activeWorkers.Add(1)
|
||||||
|
|
||||||
|
// Process event with background context (detached from original request)
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := wp.processor(ctx, event); err != nil {
|
||||||
|
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wp.activeWorkers.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Worker %d stopped", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueueSize returns the current queue size
|
||||||
|
func (wp *workerPool) QueueSize() int {
|
||||||
|
return len(wp.eventQueue)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ActiveWorkers returns the number of currently active workers
|
||||||
|
func (wp *workerPool) ActiveWorkers() int {
|
||||||
|
return int(wp.activeWorkers.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sentinel errors returned by the worker pool.
var (
	// ErrWorkerPoolStopped is returned by Submit when the pool is not running.
	ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
	// ErrQueueFull is returned by Submit when the event buffer is at capacity.
	ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
)

// BrokerError represents an error from the event broker, carrying a stable
// machine-readable Code alongside a human-readable Message.
type BrokerError struct {
	Code    string
	Message string
}

// Error implements the error interface; only the Message is rendered.
func (e *BrokerError) Error() string {
	return e.Message
}
|
||||||
@ -20,8 +20,23 @@ import (
|
|||||||
|
|
||||||
// Handler handles function-based SQL API requests
|
// Handler handles function-based SQL API requests
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
db common.Database
|
db common.Database
|
||||||
hooks *HookRegistry
|
hooks *HookRegistry
|
||||||
|
variablesCallback func(r *http.Request) map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SqlQueryOptions bundles the per-endpoint flags for the SQL query handlers.
type SqlQueryOptions struct {
	NoCount     bool // skip the COUNT/pagination query
	BlankParams bool // blank out unresolved input variables in the SQL text
	AllowFilter bool // allow request filters to be merged into the query
}

// NewSqlQueryOptions returns the default option set: counting enabled,
// blank-parameter substitution on, and request filters allowed.
func NewSqlQueryOptions() SqlQueryOptions {
	return SqlQueryOptions{
		BlankParams: true,
		AllowFilter: true,
	}
}
}
|
||||||
|
|
||||||
// NewHandler creates a new function API handler
|
// NewHandler creates a new function API handler
|
||||||
@ -38,6 +53,14 @@ func (h *Handler) GetDatabase() common.Database {
|
|||||||
return h.db
|
return h.db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *Handler) SetVariablesCallback(callback func(r *http.Request) map[string]interface{}) {
|
||||||
|
h.variablesCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) GetVariablesCallback() func(r *http.Request) map[string]interface{} {
|
||||||
|
return h.variablesCallback
|
||||||
|
}
|
||||||
|
|
||||||
// Hooks returns the hook registry for this handler
|
// Hooks returns the hook registry for this handler
|
||||||
// Use this to register custom hooks for operations
|
// Use this to register custom hooks for operations
|
||||||
func (h *Handler) Hooks() *HookRegistry {
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
@ -48,7 +71,7 @@ func (h *Handler) Hooks() *HookRegistry {
|
|||||||
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
type HTTPFuncType func(http.ResponseWriter, *http.Request)
|
||||||
|
|
||||||
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
// SqlQueryList creates an HTTP handler that executes a SQL query and returns a list with pagination
|
||||||
func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFilter bool) HTTPFuncType {
|
func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@ -70,6 +93,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
inputvars := make([]string, 0)
|
inputvars := make([]string, 0)
|
||||||
metainfo := make(map[string]interface{})
|
metainfo := make(map[string]interface{})
|
||||||
variables := make(map[string]interface{})
|
variables := make(map[string]interface{})
|
||||||
|
|
||||||
complexAPI := false
|
complexAPI := false
|
||||||
|
|
||||||
// Get user context from security package
|
// Get user context from security package
|
||||||
@ -93,9 +117,9 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
MetaInfo: metainfo,
|
MetaInfo: metainfo,
|
||||||
PropQry: propQry,
|
PropQry: propQry,
|
||||||
UserContext: userCtx,
|
UserContext: userCtx,
|
||||||
NoCount: pNoCount,
|
NoCount: options.NoCount,
|
||||||
BlankParams: pBlankparms,
|
BlankParams: options.BlankParams,
|
||||||
AllowFilter: pAllowFilter,
|
AllowFilter: options.AllowFilter,
|
||||||
ComplexAPI: complexAPI,
|
ComplexAPI: complexAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,13 +155,13 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
complexAPI = reqParams.ComplexAPI
|
complexAPI = reqParams.ComplexAPI
|
||||||
|
|
||||||
// Merge query string parameters
|
// Merge query string parameters
|
||||||
sqlquery = h.mergeQueryParams(r, sqlquery, variables, pAllowFilter, propQry)
|
sqlquery = h.mergeQueryParams(r, sqlquery, variables, options.AllowFilter, propQry)
|
||||||
|
|
||||||
// Merge header parameters
|
// Merge header parameters
|
||||||
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
sqlquery = h.mergeHeaderParams(r, sqlquery, variables, propQry, &complexAPI)
|
||||||
|
|
||||||
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
// Apply filters from parsed parameters (if not already applied by pAllowFilter)
|
||||||
if !pAllowFilter {
|
if !options.AllowFilter {
|
||||||
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
sqlquery = h.ApplyFilters(sqlquery, reqParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,7 +173,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
|
|
||||||
// Override pNoCount if skipcount is specified
|
// Override pNoCount if skipcount is specified
|
||||||
if reqParams.SkipCount {
|
if reqParams.SkipCount {
|
||||||
pNoCount = true
|
options.NoCount = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build metainfo
|
// Build metainfo
|
||||||
@ -164,7 +188,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||||
|
|
||||||
// Remove unused input variables
|
// Remove unused input variables
|
||||||
if pBlankparms {
|
if options.BlankParams {
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
@ -205,7 +229,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
sqlquery = fmt.Sprintf("%s \nORDER BY %s", sqlquery, ValidSQL(sortcols, "select"))
|
||||||
}
|
}
|
||||||
|
|
||||||
if !pNoCount {
|
if !options.NoCount {
|
||||||
if limit > 0 && offset > 0 {
|
if limit > 0 && offset > 0 {
|
||||||
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
sqlquery = fmt.Sprintf("%s \nLIMIT %d OFFSET %d", sqlquery, limit, offset)
|
||||||
} else if limit > 0 {
|
} else if limit > 0 {
|
||||||
@ -244,7 +268,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
// Normalize PostgreSQL types for proper JSON marshaling
|
// Normalize PostgreSQL types for proper JSON marshaling
|
||||||
dbobjlist = normalizePostgresTypesList(rows)
|
dbobjlist = normalizePostgresTypesList(rows)
|
||||||
|
|
||||||
if pNoCount {
|
if options.NoCount {
|
||||||
total = int64(len(dbobjlist))
|
total = int64(len(dbobjlist))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -386,7 +410,7 @@ func (h *Handler) SqlQueryList(sqlquery string, pNoCount, pBlankparms, pAllowFil
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
// SqlQuery creates an HTTP handler that executes a SQL query and returns a single record
|
||||||
func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncType {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
@ -406,6 +430,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
inputvars := make([]string, 0)
|
inputvars := make([]string, 0)
|
||||||
metainfo := make(map[string]interface{})
|
metainfo := make(map[string]interface{})
|
||||||
variables := make(map[string]interface{})
|
variables := make(map[string]interface{})
|
||||||
|
|
||||||
dbobj := make(map[string]interface{})
|
dbobj := make(map[string]interface{})
|
||||||
complexAPI := false
|
complexAPI := false
|
||||||
|
|
||||||
@ -430,7 +455,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
MetaInfo: metainfo,
|
MetaInfo: metainfo,
|
||||||
PropQry: propQry,
|
PropQry: propQry,
|
||||||
UserContext: userCtx,
|
UserContext: userCtx,
|
||||||
BlankParams: pBlankparms,
|
BlankParams: options.BlankParams,
|
||||||
ComplexAPI: complexAPI,
|
ComplexAPI: complexAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -507,7 +532,7 @@ func (h *Handler) SqlQuery(sqlquery string, pBlankparms bool) HTTPFuncType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove unused input variables
|
// Remove unused input variables
|
||||||
if pBlankparms {
|
if options.BlankParams {
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
@ -631,8 +656,18 @@ func (h *Handler) extractInputVariables(sqlquery string, inputvars *[]string) st
|
|||||||
|
|
||||||
// mergePathParams merges URL path parameters into the SQL query
|
// mergePathParams merges URL path parameters into the SQL query
|
||||||
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
func (h *Handler) mergePathParams(r *http.Request, sqlquery string, variables map[string]interface{}) string {
|
||||||
// Note: Path parameters would typically come from a router like gorilla/mux
|
|
||||||
// For now, this is a placeholder for path parameter extraction
|
if h.GetVariablesCallback() != nil {
|
||||||
|
pathVars := h.GetVariablesCallback()(r)
|
||||||
|
for k, v := range pathVars {
|
||||||
|
kword := fmt.Sprintf("[%s]", k)
|
||||||
|
if strings.Contains(sqlquery, kword) {
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kword, fmt.Sprintf("%v", v))
|
||||||
|
}
|
||||||
|
variables[k] = v
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
return sqlquery
|
return sqlquery
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -70,6 +70,10 @@ func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Data
|
|||||||
return fn(m)
|
return fn(m)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
// MockResult implements common.Result interface for testing
|
// MockResult implements common.Result interface for testing
|
||||||
type MockResult struct {
|
type MockResult struct {
|
||||||
rows int64
|
rows int64
|
||||||
@ -532,7 +536,7 @@ func TestSqlQuery(t *testing.T) {
|
|||||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
|
handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams})
|
||||||
handlerFunc(w, req)
|
handlerFunc(w, req)
|
||||||
|
|
||||||
if w.Code != tt.expectedStatus {
|
if w.Code != tt.expectedStatus {
|
||||||
@ -655,7 +659,7 @@ func TestSqlQueryList(t *testing.T) {
|
|||||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
|
handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
|
||||||
handlerFunc(w, req)
|
handlerFunc(w, req)
|
||||||
|
|
||||||
if w.Code != tt.expectedStatus {
|
if w.Code != tt.expectedStatus {
|
||||||
|
|||||||
@ -576,7 +576,7 @@ func TestHookIntegrationWithHandler(t *testing.T) {
|
|||||||
req := createTestRequest("GET", "/test", nil, nil, nil)
|
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{})
|
||||||
handlerFunc(w, req)
|
handlerFunc(w, req)
|
||||||
|
|
||||||
if !hookCalled {
|
if !hookCalled {
|
||||||
|
|||||||
@ -30,6 +30,15 @@ type Provider interface {
|
|||||||
// UpdateCacheSize updates the cache size metric
|
// UpdateCacheSize updates the cache size metric
|
||||||
UpdateCacheSize(provider string, size int64)
|
UpdateCacheSize(provider string, size int64)
|
||||||
|
|
||||||
|
// RecordEventPublished records an event publication
|
||||||
|
RecordEventPublished(source, eventType string)
|
||||||
|
|
||||||
|
// RecordEventProcessed records an event processing with its status
|
||||||
|
RecordEventProcessed(source, eventType, status string, duration time.Duration)
|
||||||
|
|
||||||
|
// UpdateEventQueueSize updates the event queue size metric
|
||||||
|
UpdateEventQueueSize(size int64)
|
||||||
|
|
||||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||||
Handler() http.Handler
|
Handler() http.Handler
|
||||||
}
|
}
|
||||||
@ -59,9 +68,13 @@ func (n *NoOpProvider) IncRequestsInFlight()
|
|||||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||||
}
|
}
|
||||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||||
|
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
|
||||||
|
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||||
|
}
|
||||||
|
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
|
||||||
func (n *NoOpProvider) Handler() http.Handler {
|
func (n *NoOpProvider) Handler() http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
|||||||
@ -6,15 +6,37 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ModelRules defines the permissions and security settings for a model.
type ModelRules struct {
	CanRead          bool // GET operations permitted
	CanUpdate        bool // PUT/PATCH operations permitted
	CanCreate        bool // POST operations permitted
	CanDelete        bool // DELETE operations permitted
	SecurityDisabled bool // when true, security checks are skipped for this model
}

// DefaultModelRules returns the permissive default rule set: every CRUD
// operation is allowed and security checks stay enabled (SecurityDisabled
// keeps its false zero value).
func DefaultModelRules() ModelRules {
	return ModelRules{
		CanRead:   true,
		CanUpdate: true,
		CanCreate: true,
		CanDelete: true,
	}
}
|
||||||
|
|
||||||
// DefaultModelRegistry implements ModelRegistry interface
|
// DefaultModelRegistry implements ModelRegistry interface
|
||||||
type DefaultModelRegistry struct {
|
type DefaultModelRegistry struct {
|
||||||
models map[string]interface{}
|
models map[string]interface{}
|
||||||
|
rules map[string]ModelRules
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global default registry instance
|
// Global default registry instance
|
||||||
var defaultRegistry = &DefaultModelRegistry{
|
var defaultRegistry = &DefaultModelRegistry{
|
||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
|
rules: make(map[string]ModelRules),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Global list of registries (searched in order)
|
// Global list of registries (searched in order)
|
||||||
@ -25,6 +47,7 @@ var registriesMutex sync.RWMutex
|
|||||||
func NewModelRegistry() *DefaultModelRegistry {
|
func NewModelRegistry() *DefaultModelRegistry {
|
||||||
return &DefaultModelRegistry{
|
return &DefaultModelRegistry{
|
||||||
models: make(map[string]interface{}),
|
models: make(map[string]interface{}),
|
||||||
|
rules: make(map[string]ModelRules),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -98,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.models[name] = model
|
r.models[name] = model
|
||||||
|
// Initialize with default rules if not already set
|
||||||
|
if _, exists := r.rules[name]; !exists {
|
||||||
|
r.rules[name] = DefaultModelRules()
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
|
|||||||
return r.GetModel(entity)
|
return r.GetModel(entity)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRules sets the rules for a specific model
|
||||||
|
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
|
// Check if model exists
|
||||||
|
if _, exists := r.models[name]; !exists {
|
||||||
|
return fmt.Errorf("model %s not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules[name] = rules
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRules retrieves the rules for a specific model
|
||||||
|
// Returns default rules if model exists but rules are not set
|
||||||
|
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
|
||||||
|
r.mutex.RLock()
|
||||||
|
defer r.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Check if model exists
|
||||||
|
if _, exists := r.models[name]; !exists {
|
||||||
|
return ModelRules{}, fmt.Errorf("model %s not found", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return rules if set, otherwise return default rules
|
||||||
|
if rules, exists := r.rules[name]; exists {
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return DefaultModelRules(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model with specific rules
|
||||||
|
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
|
||||||
|
// First register the model
|
||||||
|
if err := r.RegisterModel(name, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Then set the rules (we need to lock again for rules)
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
r.rules[name] = rules
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Global convenience functions using the default registry
|
// Global convenience functions using the default registry
|
||||||
|
|
||||||
// RegisterModel registers a model with the default global registry
|
// RegisterModel registers a model with the default global registry
|
||||||
@ -190,3 +265,34 @@ func GetModels() []interface{} {
|
|||||||
|
|
||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetModelRules sets the rules for a specific model in the default registry
|
||||||
|
func SetModelRules(name string, rules ModelRules) error {
|
||||||
|
return defaultRegistry.SetModelRules(name, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRules retrieves the rules for a specific model from the default registry
|
||||||
|
func GetModelRules(name string) (ModelRules, error) {
|
||||||
|
return defaultRegistry.GetModelRules(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
|
||||||
|
// Returns the first match found
|
||||||
|
func GetModelRulesByName(name string) (ModelRules, error) {
|
||||||
|
registriesMutex.RLock()
|
||||||
|
defer registriesMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, registry := range registries {
|
||||||
|
if _, err := registry.GetModel(name); err == nil {
|
||||||
|
// Model found in this registry, get its rules
|
||||||
|
return registry.GetModelRules(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model with specific rules in the default registry
|
||||||
|
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
|
||||||
|
return defaultRegistry.RegisterModelWithRules(name, model, rules)
|
||||||
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package reflection
|
package reflection
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -897,6 +898,319 @@ func GetRelationModel(model interface{}, fieldName string) interface{} {
|
|||||||
return currentModel
|
return currentModel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MapToStruct populates a struct from a map while preserving custom types
|
||||||
|
// It uses reflection to set struct fields based on map keys, matching by:
|
||||||
|
// 1. Bun tag column name
|
||||||
|
// 2. Gorm tag column name
|
||||||
|
// 3. JSON tag name
|
||||||
|
// 4. Field name (case-insensitive)
|
||||||
|
// This preserves custom types that implement driver.Valuer like SqlJSONB
|
||||||
|
func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||||
|
if dataMap == nil || target == nil {
|
||||||
|
return fmt.Errorf("dataMap and target cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetValue := reflect.ValueOf(target)
|
||||||
|
if targetValue.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("target must be a pointer to a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetValue = targetValue.Elem()
|
||||||
|
if targetValue.Kind() != reflect.Struct {
|
||||||
|
return fmt.Errorf("target must be a pointer to a struct")
|
||||||
|
}
|
||||||
|
|
||||||
|
targetType := targetValue.Type()
|
||||||
|
|
||||||
|
// Create a map of column names to field indices for faster lookup
|
||||||
|
columnToField := make(map[string]int)
|
||||||
|
for i := 0; i < targetType.NumField(); i++ {
|
||||||
|
field := targetType.Field(i)
|
||||||
|
|
||||||
|
// Skip unexported fields
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build list of possible column names for this field
|
||||||
|
var columnNames []string
|
||||||
|
|
||||||
|
// 1. Bun tag
|
||||||
|
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||||
|
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||||
|
columnNames = append(columnNames, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Gorm tag
|
||||||
|
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||||
|
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||||
|
columnNames = append(columnNames, colName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. JSON tag
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
columnNames = append(columnNames, parts[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Field name variations
|
||||||
|
columnNames = append(columnNames, field.Name)
|
||||||
|
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||||
|
columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||||
|
|
||||||
|
// Map all column name variations to this field index
|
||||||
|
for _, colName := range columnNames {
|
||||||
|
columnToField[strings.ToLower(colName)] = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through the map and set struct fields
|
||||||
|
for key, value := range dataMap {
|
||||||
|
// Find the field index for this key
|
||||||
|
fieldIndex, found := columnToField[strings.ToLower(key)]
|
||||||
|
if !found {
|
||||||
|
// Skip keys that don't map to any field
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
field := targetValue.Field(fieldIndex)
|
||||||
|
if !field.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the value, preserving custom types
|
||||||
|
if err := setFieldValue(field, value); err != nil {
|
||||||
|
return fmt.Errorf("failed to set field %s: %w", targetType.Field(fieldIndex).Name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setFieldValue sets a reflect.Value from an interface{} value, handling type conversions
|
||||||
|
func setFieldValue(field reflect.Value, value interface{}) error {
|
||||||
|
if value == nil {
|
||||||
|
// Set zero value for nil
|
||||||
|
field.Set(reflect.Zero(field.Type()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
valueReflect := reflect.ValueOf(value)
|
||||||
|
|
||||||
|
// If types match exactly, just set it
|
||||||
|
if valueReflect.Type().AssignableTo(field.Type()) {
|
||||||
|
field.Set(valueReflect)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointer fields
|
||||||
|
if field.Kind() == reflect.Ptr {
|
||||||
|
if valueReflect.Kind() != reflect.Ptr {
|
||||||
|
// Create a new pointer and set its value
|
||||||
|
newPtr := reflect.New(field.Type().Elem())
|
||||||
|
if err := setFieldValue(newPtr.Elem(), value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
field.Set(newPtr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle conversions for basic types
|
||||||
|
switch field.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
field.SetString(str)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
if num, ok := convertToInt64(value); ok {
|
||||||
|
if field.OverflowInt(num) {
|
||||||
|
return fmt.Errorf("integer overflow")
|
||||||
|
}
|
||||||
|
field.SetInt(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
if num, ok := convertToUint64(value); ok {
|
||||||
|
if field.OverflowUint(num) {
|
||||||
|
return fmt.Errorf("unsigned integer overflow")
|
||||||
|
}
|
||||||
|
field.SetUint(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
if num, ok := convertToFloat64(value); ok {
|
||||||
|
if field.OverflowFloat(num) {
|
||||||
|
return fmt.Errorf("float overflow")
|
||||||
|
}
|
||||||
|
field.SetFloat(num)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Bool:
|
||||||
|
if b, ok := value.(bool); ok {
|
||||||
|
field.SetBool(b)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
case reflect.Slice:
|
||||||
|
// Handle []byte specially (for types like SqlJSONB)
|
||||||
|
if field.Type().Elem().Kind() == reflect.Uint8 {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []byte:
|
||||||
|
field.SetBytes(v)
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
field.SetBytes([]byte(v))
|
||||||
|
return nil
|
||||||
|
case map[string]interface{}, []interface{}:
|
||||||
|
// Marshal complex types to JSON for SqlJSONB fields
|
||||||
|
jsonBytes, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal value to JSON: %w", err)
|
||||||
|
}
|
||||||
|
field.SetBytes(jsonBytes)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||||
|
if field.Kind() == reflect.Struct {
|
||||||
|
// Try to find a "Val" field (for SqlNull types) and set it
|
||||||
|
valField := field.FieldByName("Val")
|
||||||
|
if valField.IsValid() && valField.CanSet() {
|
||||||
|
// Also set Valid field to true
|
||||||
|
validField := field.FieldByName("Valid")
|
||||||
|
if validField.IsValid() && validField.CanSet() && validField.Kind() == reflect.Bool {
|
||||||
|
// Set the Val field
|
||||||
|
if err := setFieldValue(valField, value); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Set Valid to true
|
||||||
|
validField.SetBool(true)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we can convert the type, do it
|
||||||
|
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||||
|
field.Set(valueReflect.Convert(field.Type()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToInt64 attempts to convert various numeric and string values to int64.
// It returns (0, false) for unsupported types, for strings that do not parse as
// base-10 integers, and for unsigned values larger than MaxInt64 (previously
// such values silently wrapped to negative numbers).
// Floating-point inputs are truncated toward zero.
func convertToInt64(value interface{}) (int64, bool) {
	const maxInt64 = int64(1<<63 - 1)
	switch v := value.(type) {
	case int:
		return int64(v), true
	case int8:
		return int64(v), true
	case int16:
		return int64(v), true
	case int32:
		return int64(v), true
	case int64:
		return v, true
	case uint:
		// Reject values that would wrap to a negative int64.
		if uint64(v) > uint64(maxInt64) {
			return 0, false
		}
		return int64(v), true
	case uint8:
		return int64(v), true
	case uint16:
		return int64(v), true
	case uint32:
		return int64(v), true
	case uint64:
		// Reject values that would wrap to a negative int64.
		if v > uint64(maxInt64) {
			return 0, false
		}
		return int64(v), true
	case float32:
		return int64(v), true
	case float64:
		return int64(v), true
	case string:
		if num, err := strconv.ParseInt(v, 10, 64); err == nil {
			return num, true
		}
	}
	return 0, false
}
|
||||||
|
|
||||||
|
// convertToUint64 attempts to convert various numeric and string values to
// uint64. It returns (0, false) for unsupported types, for strings that do
// not parse as base-10 unsigned integers, and for negative values (previously
// e.g. int(-1) silently wrapped to 18446744073709551615, which also defeated
// the caller's OverflowUint check).
// Non-negative floating-point inputs are truncated toward zero.
func convertToUint64(value interface{}) (uint64, bool) {
	switch v := value.(type) {
	case int:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case int8:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case int16:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case int32:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case int64:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case uint:
		return uint64(v), true
	case uint8:
		return uint64(v), true
	case uint16:
		return uint64(v), true
	case uint32:
		return uint64(v), true
	case uint64:
		return v, true
	case float32:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case float64:
		if v < 0 {
			return 0, false
		}
		return uint64(v), true
	case string:
		if num, err := strconv.ParseUint(v, 10, 64); err == nil {
			return num, true
		}
	}
	return 0, false
}
|
||||||
|
|
||||||
|
// convertToFloat64 attempts to convert various numeric and string values to
// float64. It returns (0, false) for unsupported types and for strings that
// strconv.ParseFloat cannot parse.
func convertToFloat64(value interface{}) (float64, bool) {
	switch n := value.(type) {
	case float64:
		return n, true
	case float32:
		return float64(n), true
	case int:
		return float64(n), true
	case int8:
		return float64(n), true
	case int16:
		return float64(n), true
	case int32:
		return float64(n), true
	case int64:
		return float64(n), true
	case uint:
		return float64(n), true
	case uint8:
		return float64(n), true
	case uint16:
		return float64(n), true
	case uint32:
		return float64(n), true
	case uint64:
		return float64(n), true
	case string:
		parsed, err := strconv.ParseFloat(n, 64)
		if err == nil {
			return parsed, true
		}
	}
	return 0, false
}
|
||||||
|
|
||||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||||
|
|||||||
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
266
pkg/reflection/model_utils_sqltypes_test.go
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
package reflection_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlJSONB_PreservesDriverValuer(t *testing.T) {
|
||||||
|
// Test that SqlJSONB type preserves driver.Valuer interface
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Meta common.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(123),
|
||||||
|
"meta": map[string]interface{}{
|
||||||
|
"key": "value",
|
||||||
|
"num": 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the field was set
|
||||||
|
if result.ID != 123 {
|
||||||
|
t.Errorf("ID = %v, want 123", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB was populated
|
||||||
|
if len(result.Meta) == 0 {
|
||||||
|
t.Error("Meta is empty, want non-empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Most importantly: verify driver.Valuer interface works
|
||||||
|
value, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v, want nil", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value should return a string representation of the JSON
|
||||||
|
if value == nil {
|
||||||
|
t.Error("Meta.Value() returned nil, want non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check it's a valid JSON string
|
||||||
|
if str, ok := value.(string); ok {
|
||||||
|
if len(str) == 0 {
|
||||||
|
t.Error("Meta.Value() returned empty string, want valid JSON")
|
||||||
|
}
|
||||||
|
t.Logf("SqlJSONB.Value() returned: %s", str)
|
||||||
|
} else {
|
||||||
|
t.Errorf("Meta.Value() returned type %T, want string", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlJSONB_FromBytes(t *testing.T) {
|
||||||
|
// Test that SqlJSONB can be set from []byte directly
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Meta common.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes := []byte(`{"direct":"bytes"}`)
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(456),
|
||||||
|
"meta": jsonBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ID != 456 {
|
||||||
|
t.Errorf("ID = %v, want 456", result.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Meta) != string(jsonBytes) {
|
||||||
|
t.Errorf("Meta = %s, want %s", string(result.Meta), string(jsonBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer works
|
||||||
|
value, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if value == nil {
|
||||||
|
t.Error("Meta.Value() returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_AllSqlTypes(t *testing.T) {
|
||||||
|
// Test model with all SQL custom types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
CreatedAt common.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||||
|
BirthDate common.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||||
|
LoginTime common.SqlTime `bun:"login_time" json:"login_time"`
|
||||||
|
Meta common.SqlJSONB `bun:"meta" json:"meta"`
|
||||||
|
Tags common.SqlJSONB `bun:"tags" json:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
birthDate := time.Date(1990, 1, 15, 0, 0, 0, 0, time.UTC)
|
||||||
|
loginTime := time.Date(0, 1, 1, 14, 30, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(100),
|
||||||
|
"name": "Test User",
|
||||||
|
"created_at": now,
|
||||||
|
"birth_date": birthDate,
|
||||||
|
"login_time": loginTime,
|
||||||
|
"meta": map[string]interface{}{
|
||||||
|
"role": "admin",
|
||||||
|
"active": true,
|
||||||
|
},
|
||||||
|
"tags": []interface{}{"golang", "testing", "sql"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify basic fields
|
||||||
|
if result.ID != 100 {
|
||||||
|
t.Errorf("ID = %v, want 100", result.ID)
|
||||||
|
}
|
||||||
|
if result.Name != "Test User" {
|
||||||
|
t.Errorf("Name = %v, want 'Test User'", result.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlTimeStamp
|
||||||
|
if !result.CreatedAt.Valid {
|
||||||
|
t.Error("CreatedAt.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.CreatedAt.Val.Equal(now) {
|
||||||
|
t.Errorf("CreatedAt.Val = %v, want %v", result.CreatedAt.Val, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlTimeStamp
|
||||||
|
tsValue, err := result.CreatedAt.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("CreatedAt.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if tsValue == nil {
|
||||||
|
t.Error("CreatedAt.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlDate
|
||||||
|
if !result.BirthDate.Valid {
|
||||||
|
t.Error("BirthDate.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.BirthDate.Val.Equal(birthDate) {
|
||||||
|
t.Errorf("BirthDate.Val = %v, want %v", result.BirthDate.Val, birthDate)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlDate
|
||||||
|
dateValue, err := result.BirthDate.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("BirthDate.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if dateValue == nil {
|
||||||
|
t.Error("BirthDate.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlTime
|
||||||
|
if !result.LoginTime.Valid {
|
||||||
|
t.Error("LoginTime.Valid = false, want true")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for SqlTime
|
||||||
|
timeValue, err := result.LoginTime.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("LoginTime.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if timeValue == nil {
|
||||||
|
t.Error("LoginTime.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB for Meta
|
||||||
|
if len(result.Meta) == 0 {
|
||||||
|
t.Error("Meta is empty")
|
||||||
|
}
|
||||||
|
metaValue, err := result.Meta.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Meta.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if metaValue == nil {
|
||||||
|
t.Error("Meta.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify SqlJSONB for Tags
|
||||||
|
if len(result.Tags) == 0 {
|
||||||
|
t.Error("Tags is empty")
|
||||||
|
}
|
||||||
|
tagsValue, err := result.Tags.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Tags.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if tagsValue == nil {
|
||||||
|
t.Error("Tags.Value() returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("All SQL types successfully preserved driver.Valuer interface:")
|
||||||
|
t.Logf(" - SqlTimeStamp: %v", tsValue)
|
||||||
|
t.Logf(" - SqlDate: %v", dateValue)
|
||||||
|
t.Logf(" - SqlTime: %v", timeValue)
|
||||||
|
t.Logf(" - SqlJSONB (Meta): %v", metaValue)
|
||||||
|
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
|
||||||
|
// Test that SqlNull types handle nil values correctly
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
UpdatedAt common.SqlTimeStamp `bun:"updated_at" json:"updated_at"`
|
||||||
|
DeletedAt common.SqlTimeStamp `bun:"deleted_at" json:"deleted_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"id": int64(200),
|
||||||
|
"updated_at": now,
|
||||||
|
"deleted_at": nil, // Explicitly nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TestModel
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedAt should be valid
|
||||||
|
if !result.UpdatedAt.Valid {
|
||||||
|
t.Error("UpdatedAt.Valid = false, want true")
|
||||||
|
}
|
||||||
|
if !result.UpdatedAt.Val.Equal(now) {
|
||||||
|
t.Errorf("UpdatedAt.Val = %v, want %v", result.UpdatedAt.Val, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletedAt should be invalid (null)
|
||||||
|
if result.DeletedAt.Valid {
|
||||||
|
t.Error("DeletedAt.Valid = true, want false (null)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify driver.Valuer for null SqlTimeStamp
|
||||||
|
deletedValue, err := result.DeletedAt.Value()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("DeletedAt.Value() error = %v", err)
|
||||||
|
}
|
||||||
|
if deletedValue != nil {
|
||||||
|
t.Errorf("DeletedAt.Value() = %v, want nil", deletedValue)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1687,3 +1687,201 @@ func TestGetRelationModel_WithTags(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct(t *testing.T) {
|
||||||
|
// Test model with various field types
|
||||||
|
type TestModel struct {
|
||||||
|
ID int64 `bun:"id,pk" json:"id"`
|
||||||
|
Name string `bun:"name" json:"name"`
|
||||||
|
Age int `bun:"age" json:"age"`
|
||||||
|
Active bool `bun:"active" json:"active"`
|
||||||
|
Score float64 `bun:"score" json:"score"`
|
||||||
|
Data []byte `bun:"data" json:"data"`
|
||||||
|
MetaJSON []byte `bun:"meta_json" json:"meta_json"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataMap map[string]interface{}
|
||||||
|
expected TestModel
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic types conversion",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(123),
|
||||||
|
"name": "Test User",
|
||||||
|
"age": 30,
|
||||||
|
"active": true,
|
||||||
|
"score": 95.5,
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 123,
|
||||||
|
Name: "Test User",
|
||||||
|
Age: 30,
|
||||||
|
Active: true,
|
||||||
|
Score: 95.5,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice (SqlJSONB-like) from []byte",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(456),
|
||||||
|
"name": "JSON Test",
|
||||||
|
"data": []byte(`{"key":"value"}`),
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 456,
|
||||||
|
Name: "JSON Test",
|
||||||
|
Data: []byte(`{"key":"value"}`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from string",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(789),
|
||||||
|
"data": "string data",
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 789,
|
||||||
|
Data: []byte("string data"),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from map (JSON marshal)",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(999),
|
||||||
|
"meta_json": map[string]interface{}{
|
||||||
|
"field1": "value1",
|
||||||
|
"field2": 42,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 999,
|
||||||
|
MetaJSON: []byte(`{"field1":"value1","field2":42}`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Byte slice from slice (JSON marshal)",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(111),
|
||||||
|
"meta_json": []interface{}{"item1", "item2", 3},
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 111,
|
||||||
|
MetaJSON: []byte(`["item1","item2",3]`),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Field matching by bun tag",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(222),
|
||||||
|
"name": "Tagged Field",
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 222,
|
||||||
|
Name: "Tagged Field",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil values",
|
||||||
|
dataMap: map[string]interface{}{
|
||||||
|
"id": int64(333),
|
||||||
|
"data": nil,
|
||||||
|
},
|
||||||
|
expected: TestModel{
|
||||||
|
ID: 333,
|
||||||
|
Data: nil,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var result TestModel
|
||||||
|
err := MapToStruct(tt.dataMap, &result)
|
||||||
|
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compare fields individually for better error messages
|
||||||
|
if result.ID != tt.expected.ID {
|
||||||
|
t.Errorf("ID = %v, want %v", result.ID, tt.expected.ID)
|
||||||
|
}
|
||||||
|
if result.Name != tt.expected.Name {
|
||||||
|
t.Errorf("Name = %v, want %v", result.Name, tt.expected.Name)
|
||||||
|
}
|
||||||
|
if result.Age != tt.expected.Age {
|
||||||
|
t.Errorf("Age = %v, want %v", result.Age, tt.expected.Age)
|
||||||
|
}
|
||||||
|
if result.Active != tt.expected.Active {
|
||||||
|
t.Errorf("Active = %v, want %v", result.Active, tt.expected.Active)
|
||||||
|
}
|
||||||
|
if result.Score != tt.expected.Score {
|
||||||
|
t.Errorf("Score = %v, want %v", result.Score, tt.expected.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For byte slices, compare as strings for JSON data
|
||||||
|
if tt.expected.Data != nil {
|
||||||
|
if string(result.Data) != string(tt.expected.Data) {
|
||||||
|
t.Errorf("Data = %s, want %s", string(result.Data), string(tt.expected.Data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tt.expected.MetaJSON != nil {
|
||||||
|
if string(result.MetaJSON) != string(tt.expected.MetaJSON) {
|
||||||
|
t.Errorf("MetaJSON = %s, want %s", string(result.MetaJSON), string(tt.expected.MetaJSON))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapToStruct_Errors(t *testing.T) {
|
||||||
|
type TestModel struct {
|
||||||
|
ID int `bun:"id" json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
dataMap map[string]interface{}
|
||||||
|
target interface{}
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Nil dataMap",
|
||||||
|
dataMap: nil,
|
||||||
|
target: &TestModel{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Nil target",
|
||||||
|
dataMap: map[string]interface{}{"id": 1},
|
||||||
|
target: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-pointer target",
|
||||||
|
dataMap: map[string]interface{}{"id": 1},
|
||||||
|
target: TestModel{},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := MapToStruct(tt.dataMap, tt.target)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("MapToStruct() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -56,6 +56,10 @@ type HookContext struct {
|
|||||||
Abort bool // If set to true, the operation will be aborted
|
Abort bool // If set to true, the operation will be aborted
|
||||||
AbortMessage string // Message to return if aborted
|
AbortMessage string // Message to return if aborted
|
||||||
AbortCode int // HTTP status code if aborted
|
AbortCode int // HTTP status code if aborted
|
||||||
|
|
||||||
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
|
Tx common.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// HookFunc is the signature for hook functions
|
// HookFunc is the signature for hook functions
|
||||||
|
|||||||
@ -127,7 +127,7 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
|
|
||||||
// Validate and filter columns in options (log warnings for invalid columns)
|
// Validate and filter columns in options (log warnings for invalid columns)
|
||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
options = filterExtendedOptions(validator, options)
|
options = h.filterExtendedOptions(validator, options, model)
|
||||||
|
|
||||||
// Add request-scoped data to context (including options)
|
// Add request-scoped data to context (including options)
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
@ -300,6 +300,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Options: options,
|
Options: options,
|
||||||
ID: id,
|
ID: id,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||||
@ -745,9 +746,42 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
|
|
||||||
// Apply ComputedQL fields if any
|
// Apply ComputedQL fields if any
|
||||||
if len(preload.ComputedQL) > 0 {
|
if len(preload.ComputedQL) > 0 {
|
||||||
|
// Get the base table name from the related model
|
||||||
|
baseTableName := getTableNameFromModel(relatedModel)
|
||||||
|
|
||||||
|
// Convert the preload relation path to the appropriate alias format
|
||||||
|
// This is ORM-specific. Currently we only support Bun's format.
|
||||||
|
// TODO: Add support for other ORMs if needed
|
||||||
|
preloadAlias := ""
|
||||||
|
if h.db.GetUnderlyingDB() != nil {
|
||||||
|
// Check if we're using Bun by checking the type name
|
||||||
|
underlyingType := fmt.Sprintf("%T", h.db.GetUnderlyingDB())
|
||||||
|
if strings.Contains(underlyingType, "bun.DB") {
|
||||||
|
// Use Bun's alias format: lowercase with double underscores
|
||||||
|
preloadAlias = relationPathToBunAlias(preload.Relation)
|
||||||
|
}
|
||||||
|
// For GORM: GORM doesn't use the same alias format, and this fix
|
||||||
|
// may not be needed since GORM handles preloads differently
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Applying computed columns to preload %s (alias: %s, base table: %s)",
|
||||||
|
preload.Relation, preloadAlias, baseTableName)
|
||||||
|
|
||||||
for colName, colExpr := range preload.ComputedQL {
|
for colName, colExpr := range preload.ComputedQL {
|
||||||
|
// Replace table references in the expression with the preload alias
|
||||||
|
// This fixes the ambiguous column reference issue when there are multiple
|
||||||
|
// levels of recursive/nested preloads
|
||||||
|
adjustedExpr := colExpr
|
||||||
|
if baseTableName != "" && preloadAlias != "" {
|
||||||
|
adjustedExpr = replaceTableReferencesInSQL(colExpr, baseTableName, preloadAlias)
|
||||||
|
if adjustedExpr != colExpr {
|
||||||
|
logger.Debug("Adjusted computed column expression for %s: '%s' -> '%s'",
|
||||||
|
colName, colExpr, adjustedExpr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
|
||||||
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
|
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", adjustedExpr, colName))
|
||||||
// Remove the computed column from selected columns to avoid duplication
|
// Remove the computed column from selected columns to avoid duplication
|
||||||
for colIndex := range preload.Columns {
|
for colIndex := range preload.Columns {
|
||||||
if preload.Columns[colIndex] == colName {
|
if preload.Columns[colIndex] == colName {
|
||||||
@ -840,6 +874,73 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
|||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// relationPathToBunAlias converts a relation path such as "MAL.MAL.DEF" into
// Bun's alias format "mal__mal__def": lowercase, with dots replaced by double
// underscores. An empty path yields an empty alias.
func relationPathToBunAlias(relationPath string) string {
	if relationPath == "" {
		return ""
	}
	return strings.ReplaceAll(strings.ToLower(relationPath), ".", "__")
}
|
||||||
|
|
||||||
|
// replaceTableReferencesInSQL rewrites references to baseTableName in a SQL
// expression so they point at targetAlias instead. For example, with base
// table "mastertaskitem" and alias "mal__mal", "mastertaskitem.rid" becomes
// "mal__mal.rid". Both unquoted (tablename.) and quoted ("tablename".) forms
// are handled.
//
// Unlike a plain string replacement, the unquoted form is only rewritten at
// identifier boundaries, so a base table named "task" no longer corrupts
// references to "subtask" (previously "subtask.id" became "subt1.id").
func replaceTableReferencesInSQL(sqlExpr, baseTableName, targetAlias string) string {
	if sqlExpr == "" || baseTableName == "" || targetAlias == "" {
		return sqlExpr
	}

	// Quoted references are unambiguous and safe to replace wholesale.
	result := strings.ReplaceAll(sqlExpr, "\""+baseTableName+"\".", "\""+targetAlias+"\".")

	// Unquoted references: replace "tablename." only when it is not preceded
	// by an identifier character, to avoid mangling longer identifiers.
	needle := baseTableName + "."
	var out strings.Builder
	for i := 0; i < len(result); {
		j := strings.Index(result[i:], needle)
		if j < 0 {
			out.WriteString(result[i:])
			break
		}
		j += i
		out.WriteString(result[i:j])
		if j == 0 || !isSQLIdentChar(result[j-1]) {
			out.WriteString(targetAlias)
			out.WriteString(".")
		} else {
			// Part of a longer identifier (e.g. "subtask."); keep as-is.
			out.WriteString(needle)
		}
		i = j + len(needle)
	}
	return out.String()
}

// isSQLIdentChar reports whether c can appear inside an unquoted SQL
// identifier (letters, digits, underscore).
func isSQLIdentChar(c byte) bool {
	return c == '_' ||
		('0' <= c && c <= '9') ||
		('a' <= c && c <= 'z') ||
		('A' <= c && c <= 'Z')
}
|
||||||
|
|
||||||
|
// getTableNameFromModel extracts the table name from a model
|
||||||
|
// It checks the bun tag first, then falls back to converting the struct name to snake_case
|
||||||
|
func getTableNameFromModel(model interface{}) string {
|
||||||
|
if model == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
|
// Unwrap pointers
|
||||||
|
for modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Look for bun tag on embedded BaseModel
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if field.Anonymous {
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
if strings.HasPrefix(bunTag, "table:") {
|
||||||
|
return strings.TrimPrefix(bunTag, "table:")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: convert struct name to lowercase (simple heuristic)
|
||||||
|
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||||
|
return strings.ToLower(modelType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
|
||||||
// Capture panics and return error response
|
// Capture panics and return error response
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -866,6 +967,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
Options: options,
|
Options: options,
|
||||||
Data: data,
|
Data: data,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||||
@ -955,6 +1057,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
Data: modelValue,
|
Data: modelValue,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
Query: query,
|
Query: query,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeScan, itemHookCtx); err != nil {
|
||||||
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
return fmt.Errorf("BeforeScan hook failed for item %d: %w", i, err)
|
||||||
@ -1047,6 +1150,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Schema: schema,
|
Schema: schema,
|
||||||
Entity: entity,
|
Entity: entity,
|
||||||
TableName: tableName,
|
TableName: tableName,
|
||||||
|
Tx: h.db,
|
||||||
Model: model,
|
Model: model,
|
||||||
Options: options,
|
Options: options,
|
||||||
ID: id,
|
ID: id,
|
||||||
@ -1116,12 +1220,19 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
// Ensure ID is in the data map for the update
|
// Ensure ID is in the data map for the update
|
||||||
dataMap[pkName] = targetID
|
dataMap[pkName] = targetID
|
||||||
|
|
||||||
// Create update query
|
// Populate model instance from dataMap to preserve custom types (like SqlJSONB)
|
||||||
query := tx.NewUpdate().Table(tableName).SetMap(dataMap)
|
modelInstance := reflect.New(reflect.TypeOf(model).Elem()).Interface()
|
||||||
|
if err := reflection.MapToStruct(dataMap, modelInstance); err != nil {
|
||||||
|
return fmt.Errorf("failed to populate model from data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create update query using Model() to preserve custom types and driver.Valuer interfaces
|
||||||
|
query := tx.NewUpdate().Model(modelInstance).Table(tableName)
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
hookCtx.Query = query
|
hookCtx.Query = query
|
||||||
|
hookCtx.Tx = tx
|
||||||
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeScan, hookCtx); err != nil {
|
||||||
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
return fmt.Errorf("BeforeScan hook failed: %w", err)
|
||||||
}
|
}
|
||||||
@ -1217,6 +1328,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemID,
|
ID: itemID,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@ -1285,6 +1397,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemIDStr,
|
ID: itemIDStr,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@ -1337,6 +1450,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: itemIDStr,
|
ID: itemIDStr,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: tx,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@ -1390,6 +1504,7 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
Model: model,
|
Model: model,
|
||||||
ID: id,
|
ID: id,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
@ -2241,7 +2356,7 @@ func (h *Handler) setRowNumbersOnRecords(records any, offset int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
// filterExtendedOptions filters all column references, removing invalid ones and logging warnings
|
||||||
func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions) ExtendedRequestOptions {
|
func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRequestOptions, model interface{}) ExtendedRequestOptions {
|
||||||
filtered := options
|
filtered := options
|
||||||
|
|
||||||
// Filter base RequestOptions
|
// Filter base RequestOptions
|
||||||
@ -2265,12 +2380,30 @@ func filterExtendedOptions(validator *common.ColumnValidator, options ExtendedRe
|
|||||||
// No filtering needed for ComputedQL keys
|
// No filtering needed for ComputedQL keys
|
||||||
filtered.ComputedQL = options.ComputedQL
|
filtered.ComputedQL = options.ComputedQL
|
||||||
|
|
||||||
// Filter Expand columns
|
// Filter Expand columns using the expand relation's model
|
||||||
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
filteredExpands := make([]ExpandOption, 0, len(options.Expand))
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
for _, expand := range options.Expand {
|
for _, expand := range options.Expand {
|
||||||
filteredExpand := expand
|
filteredExpand := expand
|
||||||
// Don't validate relation name, only columns
|
|
||||||
filteredExpand.Columns = validator.FilterValidColumns(expand.Columns)
|
// Get the relationship info for this expand relation
|
||||||
|
relInfo := h.getRelationshipInfo(modelType, expand.Relation)
|
||||||
|
if relInfo != nil && relInfo.relatedModel != nil {
|
||||||
|
// Create a validator for the related model
|
||||||
|
expandValidator := common.NewColumnValidator(relInfo.relatedModel)
|
||||||
|
// Filter columns using the related model's validator
|
||||||
|
filteredExpand.Columns = expandValidator.FilterValidColumns(expand.Columns)
|
||||||
|
} else {
|
||||||
|
// If we can't find the relationship, log a warning and skip column filtering
|
||||||
|
logger.Warn("Cannot validate columns for unknown relation: %s", expand.Relation)
|
||||||
|
// Keep the columns as-is if we can't validate them
|
||||||
|
filteredExpand.Columns = expand.Columns
|
||||||
|
}
|
||||||
|
|
||||||
filteredExpands = append(filteredExpands, filteredExpand)
|
filteredExpands = append(filteredExpands, filteredExpand)
|
||||||
}
|
}
|
||||||
filtered.Expand = filteredExpands
|
filtered.Expand = filteredExpands
|
||||||
|
|||||||
@ -55,6 +55,10 @@ type HookContext struct {
|
|||||||
|
|
||||||
// Response writer - allows hooks to modify response
|
// Response writer - allows hooks to modify response
|
||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
|
|
||||||
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
|
Tx common.Database
|
||||||
}
|
}
|
||||||
|
|
||||||
// HookFunc is the signature for hook functions
|
// HookFunc is the signature for hook functions
|
||||||
|
|||||||
@ -150,6 +150,50 @@ func ExampleRelatedDataHook(ctx *HookContext) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExampleTxHook demonstrates using the Tx field to execute additional SQL queries
|
||||||
|
// The Tx field provides access to the database/transaction for custom queries
|
||||||
|
func ExampleTxHook(ctx *HookContext) error {
|
||||||
|
// Example: Execute additional SQL operations alongside the main query
|
||||||
|
// This is useful for maintaining data consistency, updating related records, etc.
|
||||||
|
|
||||||
|
if ctx.Entity == "orders" && ctx.Data != nil {
|
||||||
|
// Example: Update inventory when an order is created
|
||||||
|
// Extract product ID and quantity from the order data
|
||||||
|
// dataMap, ok := ctx.Data.(map[string]interface{})
|
||||||
|
// if !ok {
|
||||||
|
// return fmt.Errorf("invalid data format")
|
||||||
|
// }
|
||||||
|
// productID := dataMap["product_id"]
|
||||||
|
// quantity := dataMap["quantity"]
|
||||||
|
|
||||||
|
// Use ctx.Tx to execute additional SQL queries
|
||||||
|
// The Tx field contains the same database/transaction as the main operation
|
||||||
|
// If inside a transaction, your queries will be part of the same transaction
|
||||||
|
// query := ctx.Tx.NewUpdate().
|
||||||
|
// Table("inventory").
|
||||||
|
// Set("quantity = quantity - ?", quantity).
|
||||||
|
// Where("product_id = ?", productID)
|
||||||
|
//
|
||||||
|
// if _, err := query.Exec(ctx.Context); err != nil {
|
||||||
|
// logger.Error("Failed to update inventory: %v", err)
|
||||||
|
// return fmt.Errorf("failed to update inventory: %w", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// You can also execute raw SQL using ctx.Tx
|
||||||
|
// var result []map[string]interface{}
|
||||||
|
// err := ctx.Tx.Query(ctx.Context, &result,
|
||||||
|
// "INSERT INTO order_history (order_id, status) VALUES (?, ?)",
|
||||||
|
// orderID, "pending")
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("failed to insert order history: %w", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
logger.Debug("Executed additional SQL for order entity")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetupExampleHooks demonstrates how to register hooks on a handler
|
// SetupExampleHooks demonstrates how to register hooks on a handler
|
||||||
func SetupExampleHooks(handler *Handler) {
|
func SetupExampleHooks(handler *Handler) {
|
||||||
hooks := handler.Hooks()
|
hooks := handler.Hooks()
|
||||||
|
|||||||
@ -29,10 +29,11 @@ type LoginRequest struct {
|
|||||||
|
|
||||||
// LoginResponse contains the result of a login attempt
|
// LoginResponse contains the result of a login attempt
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
User *UserContext `json:"user"`
|
User *UserContext `json:"user"`
|
||||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||||
|
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogoutRequest contains information for logout
|
// LogoutRequest contains information for logout
|
||||||
|
|||||||
@ -111,7 +111,7 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
|||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||||
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login query failed: %w", err)
|
return nil, fmt.Errorf("login query failed: %w", err)
|
||||||
}
|
}
|
||||||
@ -145,7 +145,7 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
|||||||
var dataJSON sql.NullString
|
var dataJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||||
err = a.db.QueryRowContext(ctx, query, reqJSON).Scan(&success, &errorMsg, &dataJSON)
|
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("logout query failed: %w", err)
|
return fmt.Errorf("logout query failed: %w", err)
|
||||||
}
|
}
|
||||||
@ -297,7 +297,7 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
|||||||
var updatedUserJSON sql.NullString
|
var updatedUserJSON sql.NullString
|
||||||
|
|
||||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, userJSON).Scan(&success, &errorMsg, &updatedUserJSON)
|
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RefreshToken implements Refreshable interface
|
// RefreshToken implements Refreshable interface
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user