Compare commits

...

93 Commits

Author SHA1 Message Date
Hein
932f12ab0a Update handler fixes for Utils bug
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
2025-12-12 17:01:37 +02:00
Hein
b22792bad6 Optional check for bun 2025-12-12 14:49:52 +02:00
Hein
e8111c01aa Fixed for relation preloading 2025-12-12 11:45:04 +02:00
Hein
5862016031 Added ModelRules
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-12 10:13:11 +02:00
Hein
2f18dde29c Added Tx common.Database to hooks 2025-12-12 09:45:44 +02:00
Hein
31ad217818 Event Broken Concept 2025-12-12 09:23:54 +02:00
Hein
7ef1d6424a Better handling for variables callback
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-11 15:57:01 +02:00
Hein
c50eeac5bf Change the SqlQuery functions parameters on Function Spec 2025-12-11 15:42:00 +02:00
Hein
6d88f2668a Updated login interface with meta 2025-12-11 14:05:27 +02:00
Hein
8a9423df6d Fixed DatabaseAuthenticator JSON value. Added make tag 2025-12-11 13:59:41 +02:00
Hein
4cc943b9d3 Added row PgSQLAdapter
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
2025-12-10 15:28:09 +02:00
Hein
68dee78a34 Fixed filterExtendedOptions 2025-12-10 12:25:23 +02:00
Hein
efb9e5d9d5 Removed the buggy filter expand columns 2025-12-10 12:15:18 +02:00
Hein
490ae37c6d Fixed bugs in extractTableAndColumn 2025-12-10 11:48:03 +02:00
Hein
99307e31e6 More debugging on bun for scan issues 2025-12-10 11:16:25 +02:00
Hein
e3f7869c6d Bun scan debugging 2025-12-10 11:07:18 +02:00
Hein
c696d502c5 extractTableAndColumn 2025-12-10 10:10:55 +02:00
Hein
4ed1fba6ad Fixed extractTableAndColumn 2025-12-10 10:10:43 +02:00
Hein
1d0407a16d Fixed linting
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-10 10:00:01 +02:00
Hein
99001c749d Better sql where validation 2025-12-10 09:52:13 +02:00
Hein
1f7a57f8e3 Tracking provider 2025-12-10 09:31:55 +02:00
Hein
a95c28a0bf Multi Token warning and handling 2025-12-10 08:44:37 +02:00
Hein
e1abd5ebc1 Enhanced the SanitizeWhereClause function 2025-12-10 08:36:24 +02:00
Hein
ca4e53969b Better tests
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 15:32:16 +02:00
Hein
db2b7e878e Better handling of preloads 2025-12-09 15:12:17 +02:00
Hein
9572bfc7b8 Fix qualified column reference (like APIL.rid_hub) in a preload: 2025-12-09 14:46:33 +02:00
Hein
f0962ea1ec Added EnableQueryDebug log 2025-12-09 14:37:09 +02:00
Hein
8fcb065b42 Better Query Debugging 2025-12-09 14:31:26 +02:00
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
Hein
c52afe2825 Updated sql types 2025-12-09 13:14:22 +02:00
Hein
76e98d02c3 Added modelregistry.GetDefaultRegistry 2025-12-09 12:12:10 +02:00
Hein
23e2db1496 Fixed linting 2025-12-09 12:02:44 +02:00
Hein
d188f49126 Added openapi spec 2025-12-09 12:01:21 +02:00
Hein
0f05202438 Database Authenticator with cache 2025-12-09 11:32:44 +02:00
Hein
b2115038f2 Fixed providers 2025-12-09 11:18:11 +02:00
Hein
229ee4fb28 Fixed DatabaseAuthenticator sq select 2025-12-09 11:05:48 +02:00
Hein
2cf760b979 Added a few auth shortcuts 2025-12-09 10:31:08 +02:00
Hein
0a9c107095 Fixed sqlquery bug in funcspec 2025-12-09 10:19:03 +02:00
Hein
4e2fe33b77 Fixed session_rid in funcspec 2025-12-09 10:04:39 +02:00
Hein
1baa0af0ac Config Package
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 09:19:56 +02:00
Hein
659b2925e4 Cursor pagnation for resolvespec 2025-12-09 08:51:15 +02:00
Hein
baca70cafc Split coverage reports
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:20:40 +02:00
Hein
ed57978620 go-version 1.24 2025-12-08 17:14:04 +02:00
Hein
97b39de88a Broken linting 2025-12-08 17:12:44 +02:00
Hein
bf955b7971 Updated version
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:08:23 +02:00
Hein
545856f8a0 Fixed linting issues 2025-12-08 17:07:13 +02:00
Hein
8d123e47bd Updated deps on workflow 2025-12-08 16:59:49 +02:00
Hein
c9eaf84125 A lot more tests 2025-12-08 16:56:48 +02:00
Hein
aeae9d7e0c Added blacklist middleware 2025-12-08 09:26:36 +02:00
Hein
2a84652dba Middleware enhancements 2025-12-08 08:47:13 +02:00
Hein
b741958895 Code sanity fixes, added middlewares
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-08 08:28:43 +02:00
Hein
2442589982 Better headers
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-03 14:42:38 +02:00
Hein
7c1bae60c9 Added meta handlers 2025-12-03 13:52:06 +02:00
Hein
06b2404c0c Remove blank array if no args 2025-12-03 12:25:51 +02:00
Hein
32007480c6 Handle cql columns as text by default 2025-12-03 12:18:33 +02:00
Hein
ab1ce869b6 Handling JSON responses in funcspec 2025-12-03 12:10:13 +02:00
Hein
ff72e04428 Added meta operation. 2025-12-03 11:59:58 +02:00
Hein
e35f8a4f14 Fix session id that is an integer. 2025-12-03 11:49:19 +02:00
Hein
5ff9a8a24e Fixed blank params on funcspec 2025-12-03 11:42:32 +02:00
Hein
81b87af6e4 Updated doc 2025-12-03 11:30:59 +02:00
Hein
f3ba314640 Refectored the mux routers. 2025-12-03 10:42:26 +02:00
Hein
93df33e274 UnderlyingRequest and UnderlyingResponseWriter
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-12-02 17:40:44 +02:00
Hein
abd045493a mux UnderlyingRequest 2025-12-02 17:34:18 +02:00
Hein
a61556d857 Added FallbackHandler 2025-12-02 17:16:34 +02:00
Hein
eaf1133575 Fixed security rules not loading 2025-12-02 16:55:12 +02:00
Hein
8172c0495d More generic security solution. 2025-12-02 16:35:08 +02:00
Hein
7a3c368121 Pass through to default handler 2025-12-02 16:09:36 +02:00
Hein
9c5c7689e9 More common handler interface 2025-12-02 15:45:24 +02:00
Hein
08050c960d Optional Authentication 2025-12-02 14:14:38 +02:00
Hein
78029fb34f Fixed formatting issues
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-01 14:56:30 +02:00
Hein
1643a5e920 Added cache, funcspec and implemented total cache 2025-12-01 14:40:54 +02:00
Hein
6bbe0ec8b0 Added function api prototype
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-24 17:00:15 +02:00
Hein
e32ec9e17e Updated the security package 2025-11-24 17:00:05 +02:00
Hein
26c175e65e Added make release to vscode tasks
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-24 10:15:23 +02:00
Hein
aa99e8e4bc Added WrapHTTPRequest 2025-11-24 10:13:48 +02:00
Hein
163593901f Huge preload chains causing errors, workaround to do seperate selects.
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-21 17:09:11 +02:00
Hein
1261960e97 Ability to handle multiple x-custom- headers
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-21 12:15:07 +02:00
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
Hein
a931b8cdd2 Better preloads 2025-11-21 10:41:58 +02:00
Hein
7e76977dcc Lots of refactoring, Fixes to preloads 2025-11-21 10:17:20 +02:00
Hein
7853a3f56a cql_columns parsing and recursive preloading. Also added legacy header support for limt(s,e) ,sort(x,y,-z) 2025-11-21 09:15:40 +02:00
Hein
c2e0c36c79 Restheadspec now takes parameters from query parameters and headers. Allows for backward compatibility with our old dojo clients 2025-11-21 08:56:58 +02:00
Hein
59bd709460 More reflection function to handle sql columns and get default sqlcolumn lists. 2025-11-21 08:35:46 +02:00
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
Hein
745564f2e7 More Panic Recovery for reflection on orm 2025-11-20 15:20:21 +02:00
Hein
311e50bfdd Better relation lookup
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 14:30:59 +02:00
Hein
c95bc9e633 Added x-files feature 2025-11-20 12:47:36 +02:00
170 changed files with 44069 additions and 3376 deletions

1
.claude/readme Normal file
View File

@ -0,0 +1 @@
We use claude for testing and document generation.

52
.env.example Normal file
View File

@ -0,0 +1,52 @@
# ResolveSpec Environment Variables Example
# Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
# Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
# Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false
RESOLVESPEC_TRACING_SERVICE_NAME=resolvespec
RESOLVESPEC_TRACING_SERVICE_VERSION=1.0.0
RESOLVESPEC_TRACING_ENDPOINT=http://localhost:4318/v1/traces
# Cache Configuration
RESOLVESPEC_CACHE_PROVIDER=memory
# Redis Cache (when provider=redis)
RESOLVESPEC_CACHE_REDIS_HOST=localhost
RESOLVESPEC_CACHE_REDIS_PORT=6379
RESOLVESPEC_CACHE_REDIS_PASSWORD=
RESOLVESPEC_CACHE_REDIS_DB=0
# Memcache (when provider=memcache)
# Note: For arrays, separate values with commas
RESOLVESPEC_CACHE_MEMCACHE_SERVERS=localhost:11211
RESOLVESPEC_CACHE_MEMCACHE_MAX_IDLE_CONNS=10
RESOLVESPEC_CACHE_MEMCACHE_TIMEOUT=100ms
# Logger Configuration
RESOLVESPEC_LOGGER_DEV=false
RESOLVESPEC_LOGGER_PATH=
# Middleware Configuration
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_RPS=100.0
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_BURST=200
RESOLVESPEC_MIDDLEWARE_MAX_REQUEST_SIZE=10485760
# CORS Configuration
# Note: For arrays in env vars, separate with commas
RESOLVESPEC_CORS_ALLOWED_ORIGINS=*
RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600
# Database Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable

View File

@ -1,4 +1,4 @@
name: Tests
name: Build , Vet Test, and Lint
on:
push:
@ -9,7 +9,7 @@ on:
jobs:
test:
name: Run Tests
name: Run Vet Tests
runs-on: ubuntu-latest
strategy:
@ -38,22 +38,6 @@ jobs:
- name: Run go vet
run: go vet ./...
- name: Run tests
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
- name: Display test coverage
run: go tool cover -func=coverage.out
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v4
# with:
# file: ./coverage.out
# flags: unittests
# name: codecov-umbrella
# env:
# CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# continue-on-error: true
lint:
name: Lint Code
runs-on: ubuntu-latest

82
.github/workflows/make_tag.yml vendored Normal file
View File

@ -0,0 +1,82 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Create Go Release (Tag Versioning)
on:
workflow_dispatch:
inputs:
semver:
description: "New Version"
required: true
default: "patch"
type: choice
options:
- patch
- minor
- major
jobs:
tag_and_commit:
name: "Tag and Commit ${{ github.event.inputs.semver }}"
runs-on: linux
permissions:
contents: write # 'write' access to repository contents
pull-requests: write # 'write' access to pull requests
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Git
run: |
git config --global user.name "Hein"
git config --global user.email "hein.puth@gmail.com"
- name: Fetch latest tag
id: latest_tag
run: |
git fetch --tags
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
echo "::set-output name=tag::$latest_tag"
- name: Determine new tag version
id: new_tag
run: |
current_tag=${{ steps.latest_tag.outputs.tag }}
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
IFS='.' read -r -a version_parts <<< "$version"
major=${version_parts[0]}
minor=${version_parts[1]}
patch=${version_parts[2]}
case "${{ github.event.inputs.semver }}" in
"patch")
((patch++))
;;
"minor")
((minor++))
patch=0
;;
"release")
((major++))
minor=0
patch=0
;;
*)
echo "Invalid semver input"
exit 1
;;
esac
new_tag="v$major.$minor.$patch"
echo "::set-output name=tag::$new_tag"
- name: Create tag
run: |
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
- name: Push changes
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
force: true
tags: true

81
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,81 @@
name: Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
jobs:
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Run unit tests
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
- name: Generate coverage report
run: |
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
- name: Upload coverage
uses: actions/upload-artifact@v5
with:
name: coverage-report
path: coverage.html
integration-tests:
name: Integration Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15-alpine
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Create test databases
env:
PGPASSWORD: postgres
run: |
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
- name: Run resolvespec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
- name: Run restheadspec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
- name: Generate integration coverage
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: |
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
- name: Upload resolvespec integration coverage
uses: actions/upload-artifact@v5
with:
name: resolvespec-integration-coverage-report
path: coverage-resolvespec-integration.html
- name: Upload restheadspec integration coverage
uses: actions/upload-artifact@v5
with:
name: integration-coverage-restheadspec-report
path: coverage-restheadspec-integration

View File

@ -71,41 +71,26 @@
},
"gocritic": {
"enabled-checks": [
"appendAssign",
"assignOp",
"boolExprSimplify",
"builtinShadow",
"captLocal",
"caseOrder",
"defaultCaseOrder",
"dupArg",
"dupBranchBody",
"dupCase",
"dupSubExpr",
"elseif",
"emptyFallthrough",
"equalFold",
"flagName",
"ifElseChain",
"indexAlloc",
"initClause",
"methodExprCall",
"nilValReturn",
"rangeExprCopy",
"rangeValCopy",
"regexpMust",
"singleCaseSwitch",
"sloppyLen",
"stringXbytes",
"switchTrue",
"typeAssertChain",
"typeSwitchVar",
"underef",
"unlabelStmt",
"unnamedResult",
"unnecessaryBlock",
"weakCond",
"yodaStyleExpr"
],
"disabled-checks": [
"ifElseChain"
]
},
"revive": {

56
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,56 @@
{
"go.testFlags": [
"-v"
],
"go.testTimeout": "300s",
"go.coverOnSave": false,
"go.coverOnSingleTest": true,
"go.coverageDecorator": {
"type": "gutter"
},
"go.testEnvVars": {
"TEST_DATABASE_URL": "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
},
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
"go.buildTags": "",
"go.testTags": "",
"files.exclude": {
"**/.git": true,
"**/.DS_Store": true,
"**/coverage.out": true,
"**/coverage.html": true,
"**/coverage-integration.out": true,
"**/coverage-integration.html": true
},
"files.watcherExclude": {
"**/.git/objects/**": true,
"**/.git/subtree-cache/**": true,
"**/node_modules/*/**": true,
"**/.hg/store/**": true,
"**/vendor/**": true
},
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
"[go]": {
"editor.defaultFormatter": "golang.go",
"editor.formatOnSave": true,
"editor.insertSpaces": false,
"editor.tabSize": 4
},
"gopls": {
"ui.completion.usePlaceholders": true,
"ui.semanticTokens": true,
"ui.codelenses": {
"generate": true,
"regenerate_cgo": true,
"test": true,
"tidy": true,
"upgrade_dependency": true,
"vendor": true
}
}
}

234
.vscode/tasks.json vendored
View File

@ -6,10 +6,10 @@
"label": "go: build workspace",
"command": "build",
"options": {
"env": {
"env": {
"CGO_ENABLED": "0"
},
"cwd": "${workspaceFolder}/bin",
"cwd": "${workspaceFolder}/bin"
},
"args": [
"../..."
@ -17,12 +17,179 @@
"problemMatcher": [
"$go"
],
"group": "build",
"group": "build"
},
{
"type": "shell",
"label": "test: unit tests (all)",
"command": "go test ./pkg/resolvespec ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": {
"kind": "test",
"isDefault": true
},
"presentation": {
"reveal": "always",
"panel": "shared",
"focus": true
}
},
{
"type": "shell",
"label": "test: unit tests (resolvespec)",
"command": "go test ./pkg/resolvespec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: unit tests (restheadspec)",
"command": "go test ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration tests (automated)",
"command": "./scripts/run-integration-tests.sh",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: integration tests (resolvespec only)",
"command": "./scripts/run-integration-tests.sh resolvespec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: integration tests (restheadspec only)",
"command": "./scripts/run-integration-tests.sh restheadspec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: coverage report",
"command": "make coverage",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration coverage report",
"command": "make coverage-integration",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: start postgres",
"command": "make docker-up",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: stop postgres",
"command": "make docker-down",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: clean postgres data",
"command": "make clean",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "go",
"label": "go: test workspace",
"label": "go: test workspace (with race)",
"command": "test",
"options": {
"cwd": "${workspaceFolder}"
@ -37,13 +204,10 @@
"problemMatcher": [
"$go"
],
"group": {
"kind": "test",
"isDefault": true
},
"group": "test",
"presentation": {
"reveal": "always",
"panel": "new"
"panel": "shared"
}
},
{
@ -66,21 +230,59 @@
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test"
"group": "build"
},
{
"type": "shell",
"label": "go: full test suite",
"label": "go: lint workspace (fix)",
"command": "golangci-lint run --timeout=5m --fix",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "build"
},
{
"type": "shell",
"label": "test: all tests (unit + integration)",
"command": "make test",
"options": {
"cwd": "${workspaceFolder}"
},
"dependsOn": [
"docker: start postgres"
],
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: full suite with checks",
"dependsOrder": "sequence",
"dependsOn": [
"go: vet workspace",
"go: test workspace"
"test: unit tests (all)",
"test: integration tests (automated)"
],
"problemMatcher": [],
"group": {
"kind": "test",
"isDefault": false
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "Make Release",
"problemMatcher": [],
"command": "sh ${workspaceFolder}/make_release.sh"
}
]
}

View File

@ -1,173 +0,0 @@
# Migration Guide: Database and Router Abstraction
This guide explains how to migrate from the direct GORM/Router dependencies to the new abstracted interfaces.
## Overview of Changes
### What was changed:
1. **Database Operations**: GORM-specific code is now abstracted behind `Database` interface
2. **Router Integration**: HTTP router dependencies are abstracted behind `Router` interface
3. **Model Registry**: Models are now managed through a `ModelRegistry` interface
4. **Backward Compatibility**: Existing code continues to work with `NewAPIHandler()`
### Benefits:
- **Database Flexibility**: Switch between GORM, Bun, or other ORMs without code changes
- **Router Flexibility**: Use Gorilla Mux, Gin, Echo, or other routers
- **Better Testing**: Easy to mock database and router interactions
- **Cleaner Separation**: Business logic separated from ORM/router specifics
## Migration Path
### Option 1: No Changes Required (Backward Compatible)
Your existing code continues to work without any changes:
```go
// This still works exactly as before
handler := resolvespec.NewAPIHandler(db)
```
### Option 2: Gradual Migration to New API
#### Step 1: Use New Handler Constructor
```go
// Old way
handler := resolvespec.NewAPIHandler(gormDB)
// New way
handler := resolvespec.NewHandlerWithGORM(gormDB)
```
#### Step 2: Use Interface-based Approach
```go
// Create database adapter
dbAdapter := resolvespec.NewGormAdapter(gormDB)
// Create model registry
registry := resolvespec.NewModelRegistry()
// Register your models
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
// Create handler
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Switching Database Backends
### From GORM to Bun
```go
// Add bun dependency first:
// go get github.com/uptrace/bun
// Old GORM setup
gormDB, _ := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
gormAdapter := resolvespec.NewGormAdapter(gormDB)
// New Bun setup
sqlDB, _ := sql.Open("sqlite3", "test.db")
bunDB := bun.NewDB(sqlDB, sqlitedialect.New())
bunAdapter := resolvespec.NewBunAdapter(bunDB)
// Handler creation is identical
handler := resolvespec.NewHandler(bunAdapter, registry)
```
## Router Flexibility
### Current Gorilla Mux (Default)
```go
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
```
### BunRouter (Built-in Support)
```go
// Simple setup
router := bunrouter.New()
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
// Or using adapter
routerAdapter := resolvespec.NewStandardBunRouterAdapter()
// Use routerAdapter.GetBunRouter() for the underlying router
```
### Using Router Adapters (Advanced)
```go
// For when you want router abstraction
routerAdapter := resolvespec.NewStandardRouter()
routerAdapter.RegisterRoute("/{schema}/{entity}", handlerFunc)
```
## Model Registration
### Old Way (Still Works)
```go
// Models registered through existing models package
handler.RegisterModel("public", "users", &User{})
```
### New Way (Recommended)
```go
registry := resolvespec.NewModelRegistry()
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Interface Definitions
### Database Interface
```go
type Database interface {
NewSelect() SelectQuery
NewInsert() InsertQuery
NewUpdate() UpdateQuery
NewDelete() DeleteQuery
// ... transaction methods
}
```
### Available Adapters
- `GormAdapter` - For GORM (ready to use)
- `BunAdapter` - For Bun (add dependency: `go get github.com/uptrace/bun`)
- Easy to create custom adapters for other ORMs
## Testing Benefits
### Before (Tightly Coupled)
```go
// Hard to test - requires real GORM setup
func TestHandler(t *testing.T) {
db := setupRealGormDB()
handler := resolvespec.NewAPIHandler(db)
// ... test logic
}
```
### After (Mockable)
```go
// Easy to test - mock the interfaces
func TestHandler(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := resolvespec.NewHandler(mockDB, mockRegistry)
// ... test logic with mocks
}
```
## Breaking Changes
- **None for existing code** - Full backward compatibility maintained
- New interfaces are additive, not replacing existing APIs
## Recommended Migration Timeline
1. **Phase 1**: Use existing code (no changes needed)
2. **Phase 2**: Gradually adopt new constructors (`NewHandlerWithGORM`)
3. **Phase 3**: Move to interface-based approach when needed
4. **Phase 4**: Switch database backends if desired
## Getting Help
- Check example functions in `resolvespec.go`
- Review interface definitions in `database.go`
- Examine adapter implementations for patterns

66
Makefile Normal file
View File

@ -0,0 +1,66 @@
.PHONY: test test-unit test-integration docker-up docker-down clean
# Run all unit tests
test-unit:
@echo "Running unit tests..."
@go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
# Run all integration tests (requires PostgreSQL)
test-integration:
@echo "Running integration tests..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
# Run all tests (unit + integration)
test: test-unit test-integration
# Start PostgreSQL for integration tests
docker-up:
@echo "Starting PostgreSQL container..."
@docker-compose up -d postgres-test
@echo "Waiting for PostgreSQL to be ready..."
@sleep 5
@echo "PostgreSQL is ready!"
# Stop PostgreSQL container
docker-down:
@echo "Stopping PostgreSQL container..."
@docker-compose down
# Clean up Docker volumes and test data
clean:
@echo "Cleaning up..."
@docker-compose down -v
@echo "Cleanup complete!"
# Run integration tests with Docker (full workflow)
test-integration-docker: docker-up
@echo "Running integration tests with Docker..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
@$(MAKE) docker-down
# Check test coverage
coverage:
@echo "Generating coverage report..."
@go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
@go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# Run integration tests coverage
coverage-integration:
@echo "Generating integration test coverage report..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage-integration.out
@go tool cover -html=coverage-integration.out -o coverage-integration.html
@echo "Integration coverage report generated: coverage-integration.html"
help:
@echo "Available targets:"
@echo " test-unit - Run unit tests"
@echo " test-integration - Run integration tests (requires PostgreSQL)"
@echo " test - Run all tests"
@echo " docker-up - Start PostgreSQL container"
@echo " docker-down - Stop PostgreSQL container"
@echo " test-integration-docker - Run integration tests with Docker (automated)"
@echo " clean - Clean up Docker volumes"
@echo " coverage - Generate unit test coverage report"
@echo " coverage-integration - Generate integration test coverage report"
@echo " help - Show this help message"

636
README.md
View File

@ -1,73 +1,83 @@
# 📜 ResolveSpec 📜
![Tests](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
![1.00](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
**🆕 New in v2.0**: Database-agnostic architecture with support for GORM, Bun, and other ORMs. Router-flexible design works with Gorilla Mux, Gin, Echo, and more.
Documentation Generated by LLMs
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering.
![slogan](./generated_slogan.webp)
![1.00](./generated_slogan.webp)
## Table of Contents
- [Features](#features)
- [Installation](#installation)
- [Quick Start](#quick-start)
- [ResolveSpec (Body-Based API)](#resolvespec-body-based-api)
- [RestHeadSpec (Header-Based API)](#restheadspec-header-based-api)
- [Existing Code (Backward Compatible)](#option-1-existing-code-backward-compatible)
- [New Database-Agnostic API](#option-2-new-database-agnostic-api)
- [Router Integration](#router-integration)
- [Migration from v1.x](#migration-from-v1x)
- [Architecture](#architecture)
- [API Structure](#api-structure)
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
- [Lifecycle Hooks](#lifecycle-hooks)
- [Cursor Pagination](#cursor-pagination)
- [Response Formats](#response-formats)
- [Single Record as Object](#single-record-as-object-default-behavior)
- [Example Usage](#example-usage)
- [Recursive CRUD Operations](#recursive-crud-operations-)
- [Testing](#testing)
- [What's New](#whats-new)
* [Features](#features)
* [Installation](#installation)
* [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec-body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec-header-based-api)
* [Existing Code (Backward Compatible)](#option-1-existing-code-backward-compatible)
* [New Database-Agnostic API](#option-2-new-database-agnostic-api)
* [Router Integration](#router-integration)
* [Migration from v1.x](#migration-from-v1x)
* [Architecture](#architecture)
* [API Structure](#api-structure)
* [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
* [Lifecycle Hooks](#lifecycle-hooks)
* [Cursor Pagination](#cursor-pagination)
* [Response Formats](#response-formats)
* [Single Record as Object](#single-record-as-object-default-behavior)
* [Example Usage](#example-usage)
* [Recursive CRUD Operations](#recursive-crud-operations-)
* [Testing](#testing)
* [What's New](#whats-new)
## Features
### Core Features
- **Dynamic Data Querying**: Select specific columns and relationships to return
- **Relationship Preloading**: Load related entities with custom column selection and filters
- **Complex Filtering**: Apply multiple filters with various operators
- **Sorting**: Multi-column sort support
- **Pagination**: Built-in limit/offset and cursor-based pagination
- **Computed Columns**: Define virtual columns for complex calculations
- **Custom Operators**: Add custom SQL conditions when needed
- **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field
* **Dynamic Data Querying**: Select specific columns and relationships to return
* **Relationship Preloading**: Load related entities with custom column selection and filters
* **Complex Filtering**: Apply multiple filters with various operators
* **Sorting**: Multi-column sort support
* **Pagination**: Built-in limit/offset and cursor-based pagination (both ResolveSpec and RestHeadSpec)
* **Computed Columns**: Define virtual columns for complex calculations
* **Custom Operators**: Add custom SQL conditions when needed
* **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field
### Architecture (v2.0+)
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
- **🆕 Router Flexible**: Integrates with Gorilla Mux, Gin, Echo, or custom routers
- **🆕 Backward Compatible**: Existing code works without changes
- **🆕 Better Testing**: Mockable interfaces for easy unit testing
* **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
* **🆕 Router Flexible**: Integrates with Gorilla Mux, Gin, Echo, or custom routers
* **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
### RestHeadSpec (v2.1+)
- **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
- **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
- **🆕 Base64 Encoding**: Support for base64-encoded header values
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
* **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
* **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
* **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
* **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
* **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
* **🆕 Base64 Encoding**: Support for base64-encoded header values
### Routing & CORS (v3.0+)
* **🆕 Explicit Route Registration**: Routes created per registered model instead of dynamic lookups
* **🆕 OPTIONS Method Support**: Full OPTIONS method support returning model metadata
* **🆕 CORS Headers**: Comprehensive CORS support with all HeadSpec headers allowed
* **🆕 Better Route Control**: Customize routes per model with more flexibility
## API Structure
### URL Patterns
```
/[schema]/[table_or_entity]/[id]
/[schema]/[table_or_entity]
@ -77,7 +87,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
### Request Format
```json
```JSON
{
"operation": "read|create|update|delete",
"data": {
@ -102,7 +112,7 @@ RestHeadSpec provides an alternative REST API approach where all query options a
### Quick Example
```http
```HTTP
GET /public/users HTTP/1.1
Host: api.example.com
X-Select-Fields: id,name,email,department_id
@ -116,20 +126,22 @@ X-DetailApi: true
### Setup with GORM
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/gorilla/mux"
// Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// Register models using schema.table format
// IMPORTANT: Register models BEFORE setting up routes
// Routes are created explicitly for each registered model
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes
// Setup routes (creates explicit routes for each registered model)
// This replaces the old dynamic route lookup approach
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler)
restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server
http.ListenAndServe(":8080", router)
@ -137,7 +149,7 @@ http.ListenAndServe(":8080", router)
### Setup with Bun ORM
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/uptrace/bun"
@ -154,29 +166,70 @@ restheadspec.SetupMuxRoutes(router, handler)
### Common Headers
| Header | Description | Example |
|--------|-------------|---------|
| `X-Select-Fields` | Columns to include | `id,name,email` |
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
| `X-Preload` | Preload relations | `posts:id,title` |
| `X-Sort` | Sort columns | `-created_at,+name` |
| `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` |
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
| Header | Description | Example |
| --------------------------- | -------------------------------------------------- | ------------------------------ |
| `X-Select-Fields` | Columns to include | `id,name,email` |
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
| `X-Preload` | Preload relations | `posts:id,title` |
| `X-Sort` | Sort columns | `-created_at,+name` |
| `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` |
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
### CORS & OPTIONS Support
ResolveSpec and RestHeadSpec include comprehensive CORS support for cross-origin requests:
**OPTIONS Method**:
```HTTP
OPTIONS /public/users HTTP/1.1
```
Returns metadata with appropriate CORS headers:
```HTTP
Access-Control-Allow-Origin: *
Access-Control-Allow-Methods: GET, POST, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
Access-Control-Max-Age: 86400
Access-Control-Allow-Credentials: true
```
**Key Features**:
* OPTIONS returns model metadata (same as GET metadata endpoint)
* All HTTP methods include CORS headers automatically
* OPTIONS requests don't require authentication (CORS preflight)
* Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
* 24-hour max age to reduce preflight requests
**Configuration**:
```Go
import "github.com/bitechdev/ResolveSpec/pkg/common"
// Get default CORS config
corsConfig := common.DefaultCORSConfig()
// Customize if needed
corsConfig.AllowedOrigins = []string{"https://example.com"}
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
```
### Lifecycle Hooks
RestHeadSpec supports lifecycle hooks for all CRUD operations:
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler
@ -221,27 +274,29 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
```
**Available Hook Types**:
- `BeforeRead`, `AfterRead`
- `BeforeCreate`, `AfterCreate`
- `BeforeUpdate`, `AfterUpdate`
- `BeforeDelete`, `AfterDelete`
* `BeforeRead`, `AfterRead`
* `BeforeCreate`, `AfterCreate`
* `BeforeUpdate`, `AfterUpdate`
* `BeforeDelete`, `AfterDelete`
**HookContext** provides:
- `Context`: Request context
- `Handler`: Access to handler, database, and registry
- `Schema`, `Entity`, `TableName`: Request info
- `Model`: The registered model type
- `Options`: Parsed request options (filters, sorting, etc.)
- `ID`: Record ID (for single-record operations)
- `Data`: Request data (for create/update)
- `Result`: Operation result (for after hooks)
- `Writer`: Response writer (allows hooks to modify response)
* `Context`: Request context
* `Handler`: Access to handler, database, and registry
* `Schema`, `Entity`, `TableName`: Request info
* `Model`: The registered model type
* `Options`: Parsed request options (filters, sorting, etc.)
* `ID`: Record ID (for single-record operations)
* `Data`: Request data (for create/update)
* `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
### Cursor Pagination
RestHeadSpec supports efficient cursor-based pagination for large datasets:
```http
```HTTP
GET /public/posts HTTP/1.1
X-Sort: -created_at,+id
X-Limit: 50
@ -249,20 +304,22 @@ X-Cursor-Forward: <cursor_token>
```
**How it works**:
1. First request returns results + cursor token in response
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
3. Cursor maintains consistent ordering even with data changes
4. Supports complex multi-column sorting
**Benefits over offset pagination**:
- Consistent results when data changes
- Better performance for large offsets
- Prevents "skipped" or duplicate records
- Works with complex sort expressions
* Consistent results when data changes
* Better performance for large offsets
* Prevents "skipped" or duplicate records
* Works with complex sort expressions
**Example with hooks**:
```go
```Go
// Enable cursor pagination in a hook
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// For large tables, enforce cursor pagination
@ -278,7 +335,8 @@ handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookConte
RestHeadSpec supports multiple response formats:
**1. Simple Format** (`X-SimpleApi: true`):
```json
```JSON
[
{ "id": 1, "name": "John" },
{ "id": 2, "name": "Jane" }
@ -286,7 +344,8 @@ RestHeadSpec supports multiple response formats:
```
**2. Detail Format** (`X-DetailApi: true`, default):
```json
```JSON
{
"success": true,
"data": [...],
@ -300,7 +359,8 @@ RestHeadSpec supports multiple response formats:
```
**3. Syncfusion Format** (`X-Syncfusion: true`):
```json
```JSON
{
"result": [...],
"count": 100
@ -312,10 +372,12 @@ RestHeadSpec supports multiple response formats:
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records.
**Default behavior (enabled)**:
```http
```HTTP
GET /public/users/123
```
```json
```JSON
{
"success": true,
"data": { "id": 123, "name": "John", "email": "john@example.com" }
@ -323,7 +385,8 @@ GET /public/users/123
```
Instead of:
```json
```JSON
{
"success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
@ -331,11 +394,13 @@ Instead of:
```
**To disable** (force arrays for consistency):
```http
```HTTP
GET /public/users/123
X-Single-Record-As-Object: false
```
```json
```JSON
{
"success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
@ -343,23 +408,26 @@ X-Single-Record-As-Object: false
```
**How it works**:
- When a query returns exactly **one record**, it's returned as an object
- When a query returns **multiple records**, they're returned as an array
- Set `X-Single-Record-As-Object: false` to always receive arrays
- Works with all response formats (simple, detail, syncfusion)
- Applies to both read operations and create/update returning clauses
* When a query returns exactly **one record**, it's returned as an object
* When a query returns **multiple records**, they're returned as an array
* Set `X-Single-Record-As-Object: false` to always receive arrays
* Works with all response formats (simple, detail, syncfusion)
* Applies to both read operations and create/update returning clauses
**Benefits**:
- Cleaner API responses for single-record queries
- No need to unwrap single-element arrays on the client side
- Better TypeScript/type inference support
- Consistent with common REST API patterns
- Backward compatible via header opt-out
* Cleaner API responses for single-record queries
* No need to unwrap single-element arrays on the client side
* Better TypeScript/type inference support
* Consistent with common REST API patterns
* Backward compatible via header opt-out
## Example Usage
### Reading Data with Related Entities
```json
```JSON
POST /core/users
{
"operation": "read",
@ -397,13 +465,89 @@ POST /core/users
}
```
### Cursor Pagination (ResolveSpec)
ResolveSpec now supports cursor-based pagination for efficient traversal of large datasets:
```JSON
POST /core/posts
{
"operation": "read",
"options": {
"sort": [
{
"column": "created_at",
"direction": "desc"
},
{
"column": "id",
"direction": "asc"
}
],
"limit": 50,
"cursor_forward": "12345"
}
}
```
**How it works**:
1. First request returns results + cursor token (last record's ID)
2. Subsequent requests use `cursor_forward` or `cursor_backward` in options
3. Cursor maintains consistent ordering even when data changes
4. Supports complex multi-column sorting
**Benefits over offset pagination**:
- Consistent results when data changes between requests
- Better performance for large offsets
- Prevents "skipped" or duplicate records
- Works with complex sort expressions
**Example request sequence**:
```JSON
// First request - no cursor
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50
}
}
// Response includes data + last record ID
// Use the last record's ID as cursor_forward for next page
// Second request - with cursor
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50,
"cursor_forward": "12345" // ID of last record from previous page
}
}
// For backward pagination
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50,
"cursor_backward": "12300" // ID of first record from current page
}
}
```
### Recursive CRUD Operations (🆕)
ResolveSpec now supports automatic handling of nested object graphs with intelligent foreign key resolution. This allows you to create, update, or delete entire object hierarchies in a single request.
#### Creating Nested Objects
```json
```JSON
POST /core/users
{
"operation": "create",
@ -436,7 +580,7 @@ POST /core/users
Control individual operations for each nested record using the special `_request` field:
```json
```JSON
POST /core/users/123
{
"operation": "update",
@ -462,11 +606,12 @@ POST /core/users/123
}
```
**Supported `_request` values**:
- `insert` - Create a new related record
- `update` - Update an existing related record
- `delete` - Delete a related record
- `upsert` - Create if doesn't exist, update if exists
**Supported** **`_request`** **values**:
* `insert` - Create a new related record
* `update` - Update an existing related record
* `delete` - Delete a related record
* `upsert` - Create if doesn't exist, update if exists
#### How It Works
@ -478,14 +623,14 @@ POST /core/users/123
#### Benefits
- Reduce API round trips for complex object graphs
- Maintain referential integrity automatically
- Simplify client-side code
- Atomic operations with automatic rollback on errors
* Reduce API round trips for complex object graphs
* Maintain referential integrity automatically
* Simplify client-side code
* Atomic operations with automatic rollback on errors
## Installation
```bash
```Shell
go get github.com/bitechdev/ResolveSpec
```
@ -495,7 +640,7 @@ go get github.com/bitechdev/ResolveSpec
ResolveSpec uses JSON request bodies to specify query options:
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create handler
@ -522,7 +667,7 @@ resolvespec.SetupRoutes(router, handler)
RestHeadSpec uses HTTP headers for query options instead of request body:
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler with GORM
@ -551,7 +696,7 @@ See [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1) for compl
Your existing code continues to work without any changes:
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// This still works exactly as before
@ -569,7 +714,7 @@ ResolveSpec v2.0 introduces a new database and router abstraction layer while ma
To update your imports:
```bash
```Shell
# Update go.mod
go mod edit -replace github.com/Warky-Devs/ResolveSpec=github.com/bitechdev/ResolveSpec@latest
go mod tidy
@ -581,7 +726,7 @@ go mod tidy
Alternatively, use find and replace in your project:
```bash
```Shell
find . -type f -name "*.go" -exec sed -i 's|github.com/Warky-Devs/ResolveSpec|github.com/bitechdev/ResolveSpec|g' {} +
go mod tidy
```
@ -596,7 +741,7 @@ go mod tidy
### Detailed Migration Guide
For detailed migration instructions, examples, and best practices, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
For detailed migration instructions, examples, and best practices, see [MIGRATION\_GUIDE.md](MIGRATION_GUIDE.md).
## Architecture
@ -638,22 +783,23 @@ Your Application Code
### Supported Database Layers
- **GORM** (default, fully supported)
- **Bun** (ready to use, included in dependencies)
- **Custom ORMs** (implement the `Database` interface)
* **GORM** (default, fully supported)
* **Bun** (ready to use, included in dependencies)
* **Custom ORMs** (implement the `Database` interface)
### Supported Routers
- **Gorilla Mux** (built-in support with `SetupRoutes()`)
- **BunRouter** (built-in support with `SetupBunRouterWithResolveSpec()`)
- **Gin** (manual integration, see examples above)
- **Echo** (manual integration, see examples above)
- **Custom Routers** (implement request/response adapters)
* **Gorilla Mux** (built-in support with `SetupRoutes()`)
* **BunRouter** (built-in support with `SetupBunRouterWithResolveSpec()`)
* **Gin** (manual integration, see examples above)
* **Echo** (manual integration, see examples above)
* **Custom Routers** (implement request/response adapters)
### Option 2: New Database-Agnostic API
#### With GORM (Recommended Migration Path)
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create database adapter
@ -669,7 +815,8 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
```
#### With Bun ORM
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
import "github.com/uptrace/bun"
@ -684,22 +831,25 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
### Router Integration
#### Gorilla Mux (Built-in Support)
```go
```Go
import "github.com/gorilla/mux"
// Backward compatible way
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
// Register models first
handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{})
// Or manually:
router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
handler.Handle(w, r, vars)
}).Methods("POST")
// Setup routes - creates explicit routes for each model
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
// Routes created: /public/users, /public/posts, etc.
// Each route includes GET, POST, and OPTIONS methods with CORS support
```
#### Gin (Custom Integration)
```go
```Go
import "github.com/gin-gonic/gin"
func setupGin(handler *resolvespec.Handler) *gin.Engine {
@ -722,7 +872,8 @@ func setupGin(handler *resolvespec.Handler) *gin.Engine {
```
#### Echo (Custom Integration)
```go
```Go
import "github.com/labstack/echo/v4"
func setupEcho(handler *resolvespec.Handler) *echo.Echo {
@ -745,7 +896,8 @@ func setupEcho(handler *resolvespec.Handler) *echo.Echo {
```
#### BunRouter (Built-in Support)
```go
```Go
import "github.com/uptrace/bunrouter"
// Simple setup with built-in function
@ -790,7 +942,8 @@ func setupFullUptrace(bunDB *bun.DB) *bunrouter.Router {
## Configuration
### Model Registration
```go
```Go
type User struct {
ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"`
@ -804,20 +957,24 @@ handler.RegisterModel("core", "users", &User{})
## Features in Detail
### Filtering
Supported operators:
- eq: Equal
- neq: Not Equal
- gt: Greater Than
- gte: Greater Than or Equal
- lt: Less Than
- lte: Less Than or Equal
- like: LIKE pattern matching
- ilike: Case-insensitive LIKE
- in: IN clause
* eq: Equal
* neq: Not Equal
* gt: Greater Than
* gte: Greater Than or Equal
* lt: Less Than
* lte: Less Than or Equal
* like: LIKE pattern matching
* ilike: Case-insensitive LIKE
* in: IN clause
### Sorting
Support for multiple sort criteria with direction:
```json
```JSON
"sort": [
{
"column": "created_at",
@ -831,8 +988,10 @@ Support for multiple sort criteria with direction:
```
### Computed Columns
Define virtual columns using SQL expressions:
```json
```JSON
"computedColumns": [
{
"name": "full_name",
@ -845,7 +1004,7 @@ Define virtual columns using SQL expressions:
### With New Architecture (Mockable)
```go
```Go
import "github.com/stretchr/testify/mock"
// Create mock database
@ -880,14 +1039,14 @@ ResolveSpec uses GitHub Actions for automated testing and quality checks. The CI
The project includes automated workflows that:
- **Test**: Run all tests with race detection and code coverage
- **Lint**: Check code quality with golangci-lint
- **Build**: Verify the project builds successfully
- **Multi-version**: Test against multiple Go versions (1.23.x, 1.24.x)
* **Test**: Run all tests with race detection and code coverage
* **Lint**: Check code quality with golangci-lint
* **Build**: Verify the project builds successfully
* **Multi-version**: Test against multiple Go versions (1.23.x, 1.24.x)
### Running Tests Locally
```bash
```Shell
# Run all tests
go test -v ./...
@ -905,13 +1064,13 @@ golangci-lint run
The project includes comprehensive test coverage:
- **Unit Tests**: Individual component testing
- **Integration Tests**: End-to-end API testing
- **CRUD Tests**: Standalone tests for both ResolveSpec and RestHeadSpec APIs
* **Unit Tests**: Individual component testing
* **Integration Tests**: End-to-end API testing
* **CRUD Tests**: Standalone tests for both ResolveSpec and RestHeadSpec APIs
To run only the CRUD standalone tests:
```bash
```Shell
go test -v ./tests -run TestCRUDStandalone
```
@ -923,18 +1082,18 @@ Check the [Actions tab](../../actions) on GitHub to see the status of recent CI
Add this badge to display CI status in your fork:
```markdown
```Markdown
![Tests](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
```
## Security Considerations
- Implement proper authentication and authorization
- Validate all input parameters
- Use prepared statements (handled by GORM/Bun/your ORM)
- Implement rate limiting
- Control access at schema/entity level
- **New**: Database abstraction layer provides additional security through interface boundaries
* Implement proper authentication and authorization
* Validate all input parameters
* Use prepared statements (handled by GORM/Bun/your ORM)
* Implement rate limiting
* Control access at schema/entity level
* **New**: Database abstraction layer provides additional security through interface boundaries
## Contributing
@ -946,73 +1105,114 @@ Add this badge to display CI status in your fork:
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## What's New
### v2.1 (Latest)
### v3.0 (Latest - December 2025)
**Explicit Route Registration (🆕)**:
* **Breaking Change**: Routes are now created explicitly for each registered model
* **Better Control**: Customize routes per model with more flexibility
* **Registration Order**: Models must be registered BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* **Benefits**: More flexible routing, easier to add custom routes per model, better performance
**OPTIONS Method & CORS Support (🆕)**:
* **OPTIONS Endpoint**: Full OPTIONS method support for CORS preflight requests
* **Metadata Response**: OPTIONS returns model metadata (same as GET /metadata)
* **CORS Headers**: Comprehensive CORS headers on all responses
* **Header Support**: All HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.) allowed
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
* **Configurable**: Customize CORS settings via `common.CORSConfig`
**Migration Notes**:
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
* This is a **breaking change** but provides better control and flexibility
### v2.1
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
* **Cursor-Based Pagination**: Efficient cursor pagination now available in ResolveSpec (body-based API)
* **Consistent with RestHeadSpec**: Both APIs now support cursor pagination for feature parity
* **Multi-Column Sort Support**: Works seamlessly with complex sorting requirements
* **Better Performance**: Improved performance for large datasets compared to offset pagination
* **SQL Safety**: Proper SQL sanitization for cursor values
**Recursive CRUD Handler (🆕 Nov 11, 2025)**:
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
- **Foreign Key Resolution**: Automatic propagation of parent IDs to child records
- **Per-Record Operations**: Control create/update/delete operations per record via `_request` field
- **Transaction Safety**: All nested operations execute atomically within database transactions
- **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships
- **Deep Nesting Support**: Handle relationships at any depth level
- **Mixed Operations**: Combine insert, update, and delete operations in a single request
* **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
* **Foreign Key Resolution**: Automatic propagation of parent IDs to child records
* **Per-Record Operations**: Control create/update/delete operations per record via `_request` field
* **Transaction Safety**: All nested operations execute atomically within database transactions
* **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships
* **Deep Nesting Support**: Handle relationships at any depth level
* **Mixed Operations**: Combine insert, update, and delete operations in a single request
**Primary Key Improvements (Nov 11, 2025)**:
- **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
- **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations
- **Computed Column Support**: Fixed computed columns functionality across handlers
* **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
* **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations
* **Computed Column Support**: Fixed computed columns functionality across handlers
**Database Adapter Enhancements (Nov 11, 2025)**:
- **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
- **Model Method Support**: Enhanced query building with proper model registration
- **Improved Type Safety**: Better handling of relationship queries with type-aware scanning
* **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
* **Model Method Support**: Enhanced query building with proper model registration
* **Improved Type Safety**: Better handling of relationship queries with type-aware scanning
**RestHeadSpec - Header-Based REST API**:
- **Header-Based Querying**: All query options via HTTP headers instead of request body
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
- **Advanced Filtering**: Field filters, search operators, AND/OR logic
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
- **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
- **Base64 Support**: Base64-encoded header values for complex queries
- **Type-Aware Filtering**: Automatic type detection and conversion for filters
* **Header-Based Querying**: All query options via HTTP headers instead of request body
* **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
* **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
* **Advanced Filtering**: Field filters, search operators, AND/OR logic
* **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
* **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
* **Base64 Support**: Base64-encoded header values for complex queries
* **Type-Aware Filtering**: Automatic type detection and conversion for filters
**Core Improvements**:
- Better model registry with schema.table format support
- Enhanced validation and error handling
- Improved reflection safety
- Fixed COUNT query issues with table aliasing
- Better pointer handling throughout the codebase
- **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec
* Better model registry with schema.table format support
* Enhanced validation and error handling
* Improved reflection safety
* Fixed COUNT query issues with table aliasing
* Better pointer handling throughout the codebase
* **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec
### v2.0
**Breaking Changes**:
- **None!** Full backward compatibility maintained
* **None!** Full backward compatibility maintained
**New Features**:
- **Database Abstraction**: Support for GORM, Bun, and custom ORMs
- **Router Flexibility**: Works with any HTTP router through adapters
- **BunRouter Integration**: Built-in support for uptrace/bunrouter
- **Better Architecture**: Clean separation of concerns with interfaces
- **Enhanced Testing**: Mockable interfaces for comprehensive testing
- **Migration Guide**: Step-by-step migration instructions
* **Database Abstraction**: Support for GORM, Bun, and custom ORMs
* **Router Flexibility**: Works with any HTTP router through adapters
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
* **Better Architecture**: Clean separation of concerns with interfaces
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
* **Migration Guide**: Step-by-step migration instructions
**Performance Improvements**:
- More efficient query building through interface design
- Reduced coupling between components
- Better memory management with interface boundaries
* More efficient query building through interface design
* Reduced coupling between components
* Better memory management with interface boundaries
## Acknowledgments
- Inspired by REST, OData, and GraphQL's flexibility
- **Header-based approach**: Inspired by REST best practices and clean API design
- **Database Support**: [GORM](https://gorm.io) and [Bun](https://bun.uptrace.dev/)
- **Router Support**: Gorilla Mux (built-in), BunRouter, Gin, Echo, and others through adapters
- Slogan generated using DALL-E
- AI used for documentation checking and correction
- Community feedback and contributions that made v2.0 and v2.1 possible
* Inspired by REST, OData, and GraphQL's flexibility
* **Header-based approach**: Inspired by REST best practices and clean API design
* **Database Support**: [GORM](https://gorm.io) and [Bun](https://bun.uptrace.dev/)
* **Router Support**: Gorilla Mux (built-in), BunRouter, Gin, Echo, and others through adapters
* Slogan generated using DALL-E
* AI used for documentation checking and correction
* Community feedback and contributions that made v2.0 and v2.1 possible

View File

@ -6,8 +6,10 @@ import (
"os"
"time"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/server"
"github.com/bitechdev/ResolveSpec/pkg/testmodels"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
@ -19,12 +21,27 @@ import (
)
func main() {
// Initialize logger
logger.Init(true)
// Load configuration
cfgMgr := config.NewManager()
if err := cfgMgr.Load(); err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
cfg, err := cfgMgr.GetConfig()
if err != nil {
log.Fatalf("Failed to get configuration: %v", err)
}
// Initialize logger with configuration
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
logger.Info("ResolveSpec test server starting")
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
// Initialize database
db, err := initDB()
db, err := initDB(cfg)
if err != nil {
logger.Error("Failed to initialize database: %+v", err)
os.Exit(1)
@ -47,32 +64,54 @@ func main() {
handler.RegisterModel("public", modelNames[i], model)
}
// Setup routes using new SetupMuxRoutes function
resolvespec.SetupMuxRoutes(r, handler)
// Setup routes using new SetupMuxRoutes function (without authentication)
resolvespec.SetupMuxRoutes(r, handler, nil)
// Start server
logger.Info("Starting server on :8080")
if err := http.ListenAndServe(":8080", r); err != nil {
// Create graceful server with configuration
srv := server.NewGracefulServer(server.Config{
Addr: cfg.Server.Addr,
Handler: r,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
DrainTimeout: cfg.Server.DrainTimeout,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
})
// Start server with graceful shutdown
logger.Info("Starting server on %s", cfg.Server.Addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Server failed to start: %v", err)
os.Exit(1)
}
}
func initDB() (*gorm.DB, error) {
func initDB(cfg *config.Config) (*gorm.DB, error) {
// Configure GORM logger based on config
logLevel := gormlog.Info
if !cfg.Logger.Dev {
logLevel = gormlog.Warn
}
newLogger := gormlog.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
gormlog.Config{
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: gormlog.Info, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true, // Disable color
SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: logLevel, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: cfg.Logger.Dev,
},
)
// Use database URL from config if available, otherwise use default SQLite
dbURL := cfg.Database.URL
if dbURL == "" {
dbURL = "test.db"
}
// Create SQLite database
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
if err != nil {
return nil, err
}

41
config.yaml Normal file
View File

@ -0,0 +1,41 @@
# ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
logger:
dev: true # Enable development mode for readable logs
path: "" # Empty means log to stdout
cache:
provider: "memory"
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
tracing:
enabled: false
database:
url: "" # Empty means use default SQLite (test.db)

57
config.yaml.example Normal file
View File

@ -0,0 +1,57 @@
# ResolveSpec Configuration Example
# This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
tracing:
enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
logger:
dev: false
path: "" # Empty for stdout, or specify file path
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB in bytes
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
database:
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"

27
docker-compose.yml Normal file
View File

@ -0,0 +1,27 @@
services:
postgres-test:
image: postgres:15-alpine
container_name: resolvespec-postgres-test
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
- "5434:5432"
volumes:
- postgres-test-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
networks:
- resolvespec-test
volumes:
postgres-test-data:
driver: local
networks:
resolvespec-test:
driver: bridge

59
go.mod
View File

@ -1,45 +1,94 @@
module github.com/bitechdev/ResolveSpec
go 1.23.0
go 1.24.0
toolchain go1.24.6
require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/getsentry/sentry-go v0.40.0
github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1
github.com/stretchr/testify v1.8.1
github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/uptrace/bun v1.2.15
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
github.com/uptrace/bun/driver/sqliteshim v1.2.15
github.com/uptrace/bunrouter v1.0.23
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.25.12
)
require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.7.1 // indirect

150
go.sum
View File

@ -1,45 +1,127 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@ -63,31 +145,71 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=

340
pkg/cache/README.md vendored Normal file
View File

@ -0,0 +1,340 @@
# Cache Package
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
## Features
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
- **Pluggable Architecture**: Easy to add custom cache providers
- **Type-Safe API**: Automatic JSON serialization/deserialization
- **TTL Support**: Configurable time-to-live for cache entries
- **Context-Aware**: All operations support Go contexts
- **Statistics**: Built-in cache statistics and monitoring
- **Pattern Deletion**: Delete keys by pattern (Redis)
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/cache
```
For Redis support:
```bash
go get github.com/redis/go-redis/v9
```
For Memcache support:
```bash
go get github.com/bradfitz/gomemcache/memcache
```
## Quick Start
### In-Memory Cache
```go
package main
import (
"context"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
func main() {
// Initialize with in-memory provider
cache.UseMemory(&cache.Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
defer cache.Close()
ctx := context.Background()
c := cache.GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John"}
c.Set(ctx, "user:1", user, 10*time.Minute)
// Retrieve a value
var retrieved User
c.Get(ctx, "user:1", &retrieved)
}
```
### Redis Cache
```go
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "",
DB: 0,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
### Memcache
```go
cache.UseMemcache(&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
## API Reference
### Core Methods
#### Set
```go
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
```
Stores a value in the cache with automatic JSON serialization.
#### Get
```go
Get(ctx context.Context, key string, dest interface{}) error
```
Retrieves and deserializes a value from the cache.
#### SetBytes / GetBytes
```go
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
GetBytes(ctx context.Context, key string) ([]byte, error)
```
Store and retrieve raw bytes without serialization.
#### Delete
```go
Delete(ctx context.Context, key string) error
```
Removes a key from the cache.
#### DeleteByPattern
```go
DeleteByPattern(ctx context.Context, pattern string) error
```
Removes all keys matching a pattern (Redis only).
#### Clear
```go
Clear(ctx context.Context) error
```
Removes all items from the cache.
#### Exists
```go
Exists(ctx context.Context, key string) bool
```
Checks if a key exists in the cache.
#### GetOrSet
```go
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
loader func() (interface{}, error)) error
```
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
#### Stats
```go
Stats(ctx context.Context) (*CacheStats, error)
```
Returns cache statistics including hits, misses, and key counts.
## Provider Configuration
### In-Memory Options
```go
&cache.Options{
DefaultTTL: 5 * time.Minute, // Default expiration time
MaxSize: 10000, // Maximum number of items
EvictionPolicy: "LRU", // Eviction strategy (future)
}
```
### Redis Configuration
```go
&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Optional authentication
DB: 0, // Database number
PoolSize: 10, // Connection pool size
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
### Memcache Configuration
```go
&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
MaxIdleConns: 2,
Timeout: 1 * time.Second,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
## Advanced Usage
### Custom Provider
```go
// Create a custom provider instance
memProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
cache.Initialize(memProvider)
```
### Lazy Loading Pattern
```go
var data ExpensiveData
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if key is not in cache
return computeExpensiveData(), nil
})
```
### Query API Cache
The package includes specialized functions for caching query results:
```go
// Cache a query result
api := "GetUsers"
query := "SELECT * FROM users WHERE active = true"
tablenames := "users"
total := int64(150)
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
// Retrieve cached query
hash := cache.HashQueryAPICache(api, query)
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
```
## Provider Comparison
| Feature | In-Memory | Redis | Memcache |
|---------|-----------|-------|----------|
| Persistence | No | Yes | No |
| Distributed | No | Yes | Yes |
| Pattern Delete | No | Yes | No |
| Statistics | Full | Full | Limited |
| Atomic Operations | Yes | Yes | Yes |
| Max Item Size | Memory | 512MB | 1MB |
## Best Practices
1. **Use contexts**: Always pass context for cancellation and timeout control
2. **Set appropriate TTLs**: Balance between freshness and performance
3. **Handle errors**: Cache misses and errors should be handled gracefully
4. **Monitor statistics**: Use Stats() to monitor cache performance
5. **Clean up**: Always call Close() when shutting down
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
## Example: Complete Application
```go
package main
import (
"context"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
type UserService struct {
cache *cache.Cache
}
func NewUserService() *UserService {
// Initialize with Redis in production, memory for testing
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Options: &cache.Options{
DefaultTTL: 10 * time.Minute,
},
})
return &UserService{
cache: cache.GetDefaultCache(),
}
}
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
var user User
cacheKey := fmt.Sprintf("user:%d", userID)
// Try to get from cache first
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
// Load from database if not in cache
return s.loadUserFromDB(userID)
})
if err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
cacheKey := fmt.Sprintf("user:%d", userID)
return s.cache.Delete(ctx, cacheKey)
}
func main() {
service := NewUserService()
defer cache.Close()
ctx := context.Background()
user, err := service.GetUser(ctx, 123)
if err != nil {
log.Fatal(err)
}
log.Printf("User: %+v", user)
}
```
## Performance Considerations
- **In-Memory**: Fastest but limited by RAM and not distributed
- **Redis**: Great for distributed systems, persistent, but network overhead
- **Memcache**: Good for distributed caching, simpler than Redis but less features
Choose based on your needs:
- Single instance? Use in-memory
- Need persistence or advanced features? Use Redis
- Simple distributed cache? Use Memcache
## License
See repository license.

76
pkg/cache/cache.go vendored Normal file
View File

@ -0,0 +1,76 @@
package cache
import (
"context"
"fmt"
"time"
)
var (
defaultCache *Cache
)
// Initialize initializes the cache with a provider.
// If not called, the package will use an in-memory provider by default.
func Initialize(provider Provider) {
defaultCache = NewCache(provider)
}
// UseMemory configures the cache to use in-memory storage.
func UseMemory(opts *Options) error {
provider := NewMemoryProvider(opts)
defaultCache = NewCache(provider)
return nil
}
// UseRedis configures the cache to use Redis storage.
func UseRedis(config *RedisConfig) error {
provider, err := NewRedisProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Redis provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// UseMemcache configures the cache to use Memcache storage.
func UseMemcache(config *MemcacheConfig) error {
provider, err := NewMemcacheProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// GetDefaultCache returns the default cache instance.
// Initializes with in-memory provider if not already initialized.
func GetDefaultCache() *Cache {
if defaultCache == nil {
_ = UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
}
return defaultCache
}
// SetDefaultCache sets a custom cache instance as the default cache.
// This is useful for testing or when you want to use a pre-configured cache instance.
func SetDefaultCache(cache *Cache) {
defaultCache = cache
}
// GetStats returns cache statistics.
func GetStats(ctx context.Context) (*CacheStats, error) {
cache := GetDefaultCache()
return cache.Stats(ctx)
}
// Close closes the cache and releases resources.
func Close() error {
if defaultCache != nil {
return defaultCache.Close()
}
return nil
}

147
pkg/cache/cache_manager.go vendored Normal file
View File

@ -0,0 +1,147 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
)
// Cache is the main cache manager that wraps a Provider.
type Cache struct {
provider Provider
}
// NewCache creates a new cache manager with the specified provider.
func NewCache(provider Provider) *Cache {
return &Cache{
provider: provider,
}
}
// Get retrieves and deserializes a value from the cache.
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
data, exists := c.provider.Get(ctx, key)
if !exists {
return fmt.Errorf("key not found: %s", key)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize: %w", err)
}
return nil
}
// GetBytes retrieves raw bytes from the cache.
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
data, exists := c.provider.Get(ctx, key)
if !exists {
return nil, fmt.Errorf("key not found: %s", key)
}
return data, nil
}
// Set serializes and stores a value in the cache with the specified TTL.
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.Set(ctx, key, data, ttl)
}
// SetBytes stores raw bytes in the cache with the specified TTL.
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return c.provider.Set(ctx, key, value, ttl)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)
}
// Clear removes all items from the cache.
func (c *Cache) Clear(ctx context.Context) error {
return c.provider.Clear(ctx)
}
// Exists checks if a key exists in the cache.
func (c *Cache) Exists(ctx context.Context, key string) bool {
return c.provider.Exists(ctx, key)
}
// Stats returns statistics about the cache.
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
return c.provider.Stats(ctx)
}
// Close closes the cache and releases any resources.
func (c *Cache) Close() error {
return c.provider.Close()
}
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
// The loader function is called only if the key is not found in cache.
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
// Try to get from cache first
err := c.Get(ctx, key, dest)
if err == nil {
return nil
}
// Load the value
value, err := loader()
if err != nil {
return fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return fmt.Errorf("failed to cache value: %w", err)
}
// Populate dest with the loaded value
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize loaded value: %w", err)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize loaded value: %w", err)
}
return nil
}
// Remember is a convenience function that caches the result of a function call.
// It's similar to GetOrSet but returns the value directly.
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
// Try to get from cache first as bytes
data, err := c.GetBytes(ctx, key)
if err == nil {
var result interface{}
if err := json.Unmarshal(data, &result); err == nil {
return result, nil
}
}
// Load the value
value, err := loader()
if err != nil {
return nil, fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return nil, fmt.Errorf("failed to cache value: %w", err)
}
return value, nil
}

69
pkg/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,69 @@
package cache
import (
"context"
"testing"
"time"
)
func TestSetDefaultCache(t *testing.T) {
// Create a custom cache instance
provider := NewMemoryProvider(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 50,
})
customCache := NewCache(provider)
// Set it as the default
SetDefaultCache(customCache)
// Verify it's now the default
retrievedCache := GetDefaultCache()
if retrievedCache != customCache {
t.Error("SetDefaultCache did not set the cache correctly")
}
// Test that we can use it
ctx := context.Background()
testKey := "test_key"
testValue := "test_value"
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
if err != nil {
t.Fatalf("Failed to set value: %v", err)
}
var result string
err = retrievedCache.Get(ctx, testKey, &result)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
if result != testValue {
t.Errorf("Expected %s, got %s", testValue, result)
}
// Clean up - reset to default
SetDefaultCache(nil)
}
func TestGetDefaultCacheInitialization(t *testing.T) {
// Reset to nil first
SetDefaultCache(nil)
// GetDefaultCache should auto-initialize
cache := GetDefaultCache()
if cache == nil {
t.Error("GetDefaultCache should auto-initialize, got nil")
}
// Should be usable
ctx := context.Background()
err := cache.Set(ctx, "test", "value", time.Minute)
if err != nil {
t.Errorf("Failed to use auto-initialized cache: %v", err)
}
// Clean up
SetDefaultCache(nil)
}

266
pkg/cache/example_usage.go vendored Normal file
View File

@ -0,0 +1,266 @@
package cache
import (
"context"
"fmt"
"log"
"time"
)
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
func ExampleInMemoryCache() {
// Initialize with in-memory provider
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John Doe"}
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved User
err = cache.Get(ctx, "user:1", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved user: %+v\n", retrieved)
// Check if key exists
exists := cache.Exists(ctx, "user:1")
fmt.Printf("Key exists: %v\n", exists)
// Delete a key
err = cache.Delete(ctx, "user:1")
if err != nil {
_ = Close()
log.Fatal(err)
}
// Get statistics
stats, err := cache.Stats(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cache stats: %+v\n", stats)
_ = Close()
}
// ExampleRedisCache demonstrates using the Redis cache provider.
func ExampleRedisCache() {
// Initialize with Redis provider
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Set if Redis requires authentication
DB: 0,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store raw bytes
data := []byte("Hello, Redis!")
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve raw bytes
retrieved, err := cache.GetBytes(ctx, "greeting")
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved data: %s\n", string(retrieved))
// Clear all cache
err = cache.Clear(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
_ = Close()
}
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
func ExampleMemcacheCache() {
// Initialize with Memcache provider
err := UseMemcache(&MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type Product struct {
ID int
Name string
Price float64
}
product := Product{ID: 100, Name: "Widget", Price: 29.99}
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved Product
err = cache.Get(ctx, "product:100", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved product: %+v\n", retrieved)
_ = Close()
}
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
func ExampleGetOrSet() {
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
type ExpensiveData struct {
Result string
}
var data ExpensiveData
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if the key is not in cache
fmt.Println("Computing expensive result...")
time.Sleep(1 * time.Second)
return ExpensiveData{Result: "computed value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Data: %+v\n", data)
// Second call will use cached value
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
fmt.Println("This won't be called!")
return ExpensiveData{Result: "new value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cached data: %+v\n", data)
_ = Close()
}
// ExampleCustomProvider demonstrates using a custom provider.
func ExampleCustomProvider() {
// Create a custom provider
memProvider := NewMemoryProvider(&Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
Initialize(memProvider)
ctx := context.Background()
cache := GetDefaultCache()
// Use the cache
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Clean expired items (memory provider specific)
if mp, ok := cache.provider.(*MemoryProvider); ok {
count := mp.CleanExpired(ctx)
fmt.Printf("Cleaned %d expired items\n", count)
}
_ = Close()
}
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
func ExampleDeleteByPattern() {
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
// Store multiple keys with a pattern
_ = cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
// Delete all keys matching pattern (Redis glob pattern)
err = cache.DeleteByPattern(ctx, "user:*:profile")
if err != nil {
_ = Close()
log.Print(err)
return
}
fmt.Println("Deleted all user profile keys")
_ = Close()
}

57
pkg/cache/provider.go vendored Normal file
View File

@ -0,0 +1,57 @@
package cache
import (
"context"
"time"
)
// Provider defines the interface that all cache providers must implement.
type Provider interface {
// Get retrieves a value from the cache by key.
// Returns nil, false if key doesn't exist or is expired.
Get(ctx context.Context, key string) ([]byte, bool)
// Set stores a value in the cache with the specified TTL.
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error
// Clear removes all items from the cache.
Clear(ctx context.Context) error
// Exists checks if a key exists in the cache.
Exists(ctx context.Context, key string) bool
// Close closes the provider and releases any resources.
Close() error
// Stats returns statistics about the cache provider.
Stats(ctx context.Context) (*CacheStats, error)
}
// CacheStats contains cache statistics.
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
ProviderType string `json:"provider_type"`
ProviderStats map[string]any `json:"provider_stats,omitempty"`
}
// Options contains configuration options for cache providers.
type Options struct {
// DefaultTTL is the default time-to-live for cache items.
DefaultTTL time.Duration
// MaxSize is the maximum number of items (for in-memory provider).
MaxSize int
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
EvictionPolicy string
}

144
pkg/cache/provider_memcache.go vendored Normal file
View File

@ -0,0 +1,144 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// MemcacheProvider is a Memcache implementation of the Provider interface.
type MemcacheProvider struct {
client *memcache.Client
options *Options
}
// MemcacheConfig contains Memcache-specific configuration.
type MemcacheConfig struct {
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
Servers []string
// MaxIdleConns is the maximum number of idle connections (default: 2)
MaxIdleConns int
// Timeout for connection operations (default: 1 second)
Timeout time.Duration
// Options contains general cache options
Options *Options
}
// NewMemcacheProvider creates a new Memcache cache provider.
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
if config == nil {
config = &MemcacheConfig{
Servers: []string{"localhost:11211"},
}
}
if len(config.Servers) == 0 {
config.Servers = []string{"localhost:11211"}
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 2
}
if config.Timeout == 0 {
config.Timeout = 1 * time.Second
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := memcache.New(config.Servers...)
client.MaxIdleConns = config.MaxIdleConns
client.Timeout = config.Timeout
// Test connection
if err := client.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
}
return &MemcacheProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
item, err := m.client.Get(key)
if err == memcache.ErrCacheMiss {
return nil, false
}
if err != nil {
return nil, false
}
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
item := &memcache.Item{
Key: key,
Value: value,
Expiration: int32(ttl.Seconds()),
}
return m.client.Set(item)
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
}
return err
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
}
// Clear removes all items from the cache.
func (m *MemcacheProvider) Clear(ctx context.Context) error {
return m.client.FlushAll()
}
// Exists checks if a key exists in the cache.
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
_, err := m.client.Get(key)
return err == nil
}
// Close closes the provider and releases any resources.
func (m *MemcacheProvider) Close() error {
// Memcache client doesn't have a close method
return nil
}
// Stats returns statistics about the cache provider.
// Note: Memcache provider returns limited statistics.
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
stats := &CacheStats{
ProviderType: "memcache",
ProviderStats: map[string]any{
"note": "Memcache does not provide detailed statistics through the standard client",
},
}
return stats, nil
}

238
pkg/cache/provider_memory.go vendored Normal file
View File

@ -0,0 +1,238 @@
package cache
import (
"context"
"fmt"
"regexp"
"sync"
"sync/atomic"
"time"
)
// memoryItem represents a cached item in memory.
type memoryItem struct {
Value []byte
Expiration time.Time
LastAccess time.Time
HitCount int64
}
// isExpired checks if the item has expired.
func (m *memoryItem) isExpired() bool {
if m.Expiration.IsZero() {
return false
}
return time.Now().After(m.Expiration)
}
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits atomic.Int64
misses atomic.Int64
}
// NewMemoryProvider creates a new in-memory cache provider.
func NewMemoryProvider(opts *Options) *MemoryProvider {
if opts == nil {
opts = &Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
}
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
}
}
// Get retrieves a value from the cache by key.
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
// First try with read lock for fast path
m.mu.RLock()
item, exists := m.items[key]
if !exists {
m.mu.RUnlock()
m.misses.Add(1)
return nil, false
}
if item.isExpired() {
m.mu.RUnlock()
// Upgrade to write lock to delete expired item
m.mu.Lock()
delete(m.items, key)
m.mu.Unlock()
m.misses.Add(1)
return nil, false
}
// Update stats and access time with write lock
value := item.Value
m.mu.RUnlock()
// Update access tracking with write lock
m.mu.Lock()
item.LastAccess = time.Now()
item.HitCount++
m.mu.Unlock()
m.hits.Add(1)
return value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()
defer m.mu.Unlock()
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
for key := range m.items {
if re.MatchString(key) {
delete(m.items, key)
}
}
return nil
}
// Clear removes all items from the cache.
func (m *MemoryProvider) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryItem)
m.hits.Store(0)
m.misses.Store(0)
return nil
}
// Exists checks if a key exists in the cache.
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
item, exists := m.items[key]
if !exists {
return false
}
return !item.isExpired()
}
// Close closes the provider and releases any resources.
func (m *MemoryProvider) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = nil
return nil
}
// Stats returns statistics about the cache provider.
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
// Clean expired items first
validKeys := 0
for _, item := range m.items {
if !item.isExpired() {
validKeys++
}
}
return &CacheStats{
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Keys: int64(validKeys),
ProviderType: "memory",
ProviderStats: map[string]any{
"capacity": m.options.MaxSize,
},
}, nil
}
// evictOne removes one item from the cache using LRU strategy.
func (m *MemoryProvider) evictOne() {
var oldestKey string
var oldestTime time.Time
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
return
}
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
oldestKey = key
oldestTime = item.LastAccess
}
}
if oldestKey != "" {
delete(m.items, oldestKey)
}
}
// CleanExpired removes all expired items from the cache.
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
m.mu.Lock()
defer m.mu.Unlock()
count := 0
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
count++
}
}
return count
}

185
pkg/cache/provider_redis.go vendored Normal file
View File

@ -0,0 +1,185 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RedisProvider is a Redis implementation of the Provider interface.
type RedisProvider struct {
client *redis.Client
options *Options
}
// RedisConfig contains Redis-specific configuration.
type RedisConfig struct {
// Host is the Redis server host (default: localhost)
Host string
// Port is the Redis server port (default: 6379)
Port int
// Password for Redis authentication (optional)
Password string
// DB is the Redis database number (default: 0)
DB int
// PoolSize is the maximum number of connections (default: 10)
PoolSize int
// Options contains general cache options
Options *Options
}
// NewRedisProvider creates a new Redis cache provider.
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
if config == nil {
config = &RedisConfig{
Host: "localhost",
Port: 6379,
DB: 0,
}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = 6379
}
if config.PoolSize == 0 {
config.PoolSize = 10
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
Password: config.Password,
DB: config.DB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
return &RedisProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
val, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return val, true
}
// Set stores a value in the cache with the specified TTL.
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
return r.client.Set(ctx, key, value, ttl).Err()
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
}
// DeleteByPattern removes all keys matching the pattern.
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
pipe := r.client.Pipeline()
count := 0
for iter.Next(ctx) {
pipe.Del(ctx, iter.Val())
count++
// Execute pipeline in batches of 100
if count%100 == 0 {
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = r.client.Pipeline()
}
}
if err := iter.Err(); err != nil {
return err
}
// Execute remaining commands
if count%100 != 0 {
_, err := pipe.Exec(ctx)
return err
}
return nil
}
// Clear removes all items from the cache.
func (r *RedisProvider) Clear(ctx context.Context) error {
return r.client.FlushDB(ctx).Err()
}
// Exists checks if a key exists in the cache.
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
result, err := r.client.Exists(ctx, key).Result()
if err != nil {
return false
}
return result > 0
}
// Close closes the provider and releases any resources.
func (r *RedisProvider) Close() error {
return r.client.Close()
}
// Stats returns statistics about the cache provider.
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
if err != nil {
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
}
dbSize, err := r.client.DBSize(ctx).Result()
if err != nil {
return nil, fmt.Errorf("failed to get DB size: %w", err)
}
// Parse stats from INFO command
// This is a simplified version - you may want to parse more detailed stats
stats := &CacheStats{
Keys: dbSize,
ProviderType: "redis",
ProviderStats: map[string]any{
"info": info,
},
}
return stats, nil
}

127
pkg/cache/query_cache.go vendored Normal file
View File

@ -0,0 +1,127 @@
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
if where, ok := expMap["where"].(string); ok {
expKey.Where = where
}
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
}

151
pkg/cache/query_cache_test.go vendored Normal file
View File

@ -0,0 +1,151 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@ -0,0 +1,218 @@
# Automatic Relation Loading Strategies
## Overview
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
Simply use `PreloadRelation()` and the system automatically:
- Detects relationship type from Bun/GORM tags
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
## How It Works
```go
// Just write this - the system handles the rest!
db.NewSelect().
Model(&links).
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
Scan(ctx, &links)
```
### Detection Logic
The system inspects your model's struct tags:
**Bun models:**
```go
type Link struct {
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
}
```
**GORM models:**
```go
type Link struct {
ProviderID int
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
}
```
**Type inference (fallback):**
- `[]Type` (slice) → has-many → Separate query
- `*Type` (pointer) → belongs-to → JOIN
- `Type` (struct) → belongs-to → JOIN
### What Gets Logged
Enable debug logging to see strategy selection:
```go
bunAdapter.EnableQueryDebug()
```
**Output:**
```
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
INFO: Using JOIN strategy for belongs-to relation 'Provider'
DEBUG: PreloadRelation 'Links' detected as: has-many
DEBUG: Using separate query for has-many relation 'Links'
```
## Relationship Types
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|---------|--------------|------------|----------|-----|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
## Manual Override
If you need to force a specific strategy, use `JoinRelation()`:
```go
// Force JOIN even for has-many (not recommended)
db.NewSelect().
Model(&providers).
JoinRelation("Links"). // Explicitly use JOIN
Scan(ctx, &providers)
```
## Examples
### Automatic Strategy Selection (Recommended)
```go
// Example 1: Loading parent provider for each link
// System detects belongs-to → uses JOIN automatically
db.NewSelect().
Model(&links).
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &links)
// Generated SQL: Single query with JOIN
// SELECT links.*, providers.*
// FROM links
// LEFT JOIN providers ON links.provider_id = providers.id
// WHERE providers.active = true
// Example 2: Loading child links for each provider
// System detects has-many → uses separate query automatically
db.NewSelect().
Model(&providers).
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &providers)
// Generated SQL: Two queries
// Query 1: SELECT * FROM providers
// Query 2: SELECT * FROM links
// WHERE provider_id IN (1, 2, 3, ...)
// AND active = true
```
### Mixed Relationships
```go
type Order struct {
ID int
CustomerID int
Customer *Customer `bun:"rel:belongs-to"` // JOIN
Items []Item `bun:"rel:has-many"` // Separate
Invoice *Invoice `bun:"rel:has-one"` // JOIN
}
// All three handled optimally!
db.NewSelect().
Model(&orders).
PreloadRelation("Customer"). // → JOIN (many-to-one)
PreloadRelation("Items"). // → Separate (one-to-many)
PreloadRelation("Invoice"). // → JOIN (one-to-one)
Scan(ctx, &orders)
```
## Performance Benefits
### Before (Manual Strategy Selection)
```go
// You had to remember which to use:
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
.PreloadRelation("Links") // Which is more efficient here?
```
### After (Automatic Selection)
```go
// Just use PreloadRelation everywhere:
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
.PreloadRelation("Links") // ✓ System uses separate query automatically
```
## Migration Guide
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
```go
// Before: Always used separate query
.PreloadRelation("Provider") // Inefficient: extra round trip
// After: Automatic optimization
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
```
## Implementation Details
### Supported Bun Tags
- `rel:has-many` → Separate query
- `rel:belongs-to` → JOIN
- `rel:has-one` → JOIN
- `rel:many-to-many` or `rel:m2m` → Separate query
### Supported GORM Patterns
- `many2many:` tag → Separate query
- `foreignKey:` tag → JOIN (belongs-to)
- `[]Type` slice without many2many → Separate query (has-many)
- `*Type` pointer with foreignKey → JOIN (belongs-to)
- `*Type` pointer without foreignKey → JOIN (has-one)
### Fallback Behavior
- `[]Type` (slice) → Separate query (safe default for collections)
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
- Unknown → Separate query (safest default)
## Debugging
To see strategy selection in action:
```go
// Enable debug logging
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
// Run your query
db.NewSelect().
Model(&records).
PreloadRelation("RelationName").
Scan(ctx, &records)
// Check logs for:
// - "PreloadRelation 'X' detected as: belongs-to"
// - "Using JOIN strategy for belongs-to relation 'X'"
// - Actual SQL queries executed
```
## Best Practices
1. **Use PreloadRelation() for everything** - Let the system optimize
2. **Define proper relationship tags** - Ensures correct detection
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
4. **Enable debug logging during development** - Verify optimal strategies are chosen
5. **Trust the system** - It's designed to choose correctly based on relationship type

View File

@ -0,0 +1,81 @@
package database
import (
"testing"
)
func TestNormalizeTableAlias(t *testing.T) {
tests := []struct {
name string
query string
expectedAlias string
tableName string
want string
}{
{
name: "strips plausible alias from simple condition",
query: "APIL.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = 2576",
},
{
name: "keeps correct alias",
query: "apiproviderlink.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "apiproviderlink.rid_hub = 2576",
},
{
name: "strips plausible alias with multiple conditions",
query: "APIL.rid_hub = ? AND APIL.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND active = ?",
},
{
name: "handles mixed correct and plausible aliases",
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND apiproviderlink.active = ?",
},
{
name: "handles parentheses",
query: "(APIL.rid_hub = ?)",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "(rid_hub = ?)",
},
{
name: "no alias in query",
query: "rid_hub = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ?",
},
{
name: "keeps reference to different table (not in current table name)",
query: "APIL.rid_hub = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "APIL.rid_hub = ?",
},
{
name: "keeps reference with short prefix that might be ambiguous",
query: "AP.rid = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "AP.rid = ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
if got != tt.want {
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
}
})
}
}

View File

@ -4,15 +4,93 @@ import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"time"
"github.com/uptrace/bun"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
type QueryDebugHook struct{}
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
return ctx
}
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
query := event.Query
duration := time.Since(event.StartTime)
if event.Err != nil {
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
} else {
logger.Debug("SQL Query Success [%s]: %s", duration, query)
}
}
// debugScanIntoStruct attempts to scan rows into a struct with detailed field-level logging
// This helps identify which specific field is causing scanning issues
func debugScanIntoStruct(rows interface{}, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct && v.Kind() != reflect.Slice {
return fmt.Errorf("dest must be pointer to struct or slice")
}
// Log the type being scanned into
typeName := v.Type().String()
logger.Debug("Debug scan into type: %s (kind: %s)", typeName, v.Kind())
// Handle slice types - inspect the element type
var structType reflect.Type
if v.Kind() == reflect.Slice {
elemType := v.Type().Elem()
logger.Debug(" Slice element type: %s", elemType)
// If slice of pointers, get the underlying type
if elemType.Kind() == reflect.Ptr {
structType = elemType.Elem()
} else {
structType = elemType
}
} else if v.Kind() == reflect.Struct {
structType = v.Type()
}
// If we have a struct type, log all its fields
if structType != nil && structType.Kind() == reflect.Struct {
logger.Debug(" Struct %s has %d fields:", structType.Name(), structType.NumField())
for i := 0; i < structType.NumField(); i++ {
field := structType.Field(i)
// Log embedded fields specially
if field.Anonymous {
logger.Debug(" [%d] EMBEDDED: %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), field.Tag.Get("bun"))
} else {
bunTag := field.Tag.Get("bun")
if bunTag == "" {
bunTag = "(no tag)"
}
logger.Debug(" [%d] %s (type: %s, kind: %s, bun:%q)",
i, field.Name, field.Type, field.Type.Kind(), bunTag)
}
}
}
return nil
}
// BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs
type BunAdapter struct {
@ -24,6 +102,28 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return &BunAdapter{db: db}
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
// EnableDetailedScanDebug enables verbose logging of scan operations
// WARNING: This generates a LOT of log output. Use only for debugging specific issues.
func (b *BunAdapter) EnableDetailedScanDebug() {
logger.Info("Detailed scan debugging enabled - will log all field scanning operations")
// This is a flag that can be checked in scan operations
// Implementation would require modifying the scan logic
}
// DisableQueryDebug removes all query hooks
func (b *BunAdapter) DisableQueryDebug() {
// Create a new DB without hooks
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
}
func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{
query: b.db.NewSelect(),
@ -43,12 +143,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
return &BunDeleteQuery{query: b.db.NewDelete()}
}
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Exec", r)
}
}()
result, err := b.db.ExecContext(ctx, query, args...)
return &BunResult{result: result}, err
}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.Query", r)
}
}()
return b.db.NewRaw(query, args...).Scan(ctx, dest)
}
@ -73,7 +183,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
return nil
}
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
}
}()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
// Create adapter with transaction
adapter := &BunTxAdapter{tx: tx}
@ -81,14 +196,28 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
})
}
func (b *BunAdapter) GetUnderlyingDB() interface{} {
return b.db
}
// BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct {
query *bun.SelectQuery
db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
query *bun.SelectQuery
db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
}
// deferredPreload represents a preload that will be executed as a separate query
// to avoid PostgreSQL identifier length limits
type deferredPreload struct {
relation string
apply []func(common.SelectQuery) common.SelectQuery
}
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@ -122,16 +251,156 @@ func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery {
}
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.ColumnExpr(query, args)
if len(args) > 0 {
b.query = b.query.ColumnExpr(query, args)
} else {
b.query = b.query.ColumnExpr(query)
}
return b
}
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if b.inJoinContext && b.joinTableAlias != "" {
query = addTablePrefix(query, b.joinTableAlias)
} else if b.tableAlias != "" && b.tableName != "" {
// If we have a table alias defined, check if the query references a different alias
// This can happen in preloads where the user expects a certain alias but Bun generates another
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
}
b.query = b.query.Where(query, args...)
return b
}
// addTablePrefix adds a table prefix to unqualified column references
// This is used in JOIN contexts where conditions must reference the joined table
func addTablePrefix(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
// (no dot, and likely a column name before an operator)
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeyword(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
func isOperatorOrKeyword(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
// isAcronymMatch checks if prefix is an acronym of tableName
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
func isAcronymMatch(prefix, tableName string) bool {
if len(prefix) == 0 || len(tableName) == 0 {
return false
}
prefixIdx := 0
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
if tableName[i] == prefix[prefixIdx] {
prefixIdx++
}
}
// All characters of prefix were found in sequence in tableName
return prefixIdx == len(prefix)
}
// normalizeTableAlias replaces table alias prefixes in SQL conditions
// This handles cases where a user references a table alias that doesn't match
// what Bun generates (common in preload contexts)
func normalizeTableAlias(query, expectedAlias, tableName string) string {
// Pattern: <word>.<column> where <word> might be an incorrect alias
// We'll look for patterns like "APIL.column" and either:
// 1. Remove the alias prefix if it's clearly meant for this table
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
// Split on spaces and parentheses to find qualified references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like a qualified column reference
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
prefix := part[:dotIndex]
column := part[dotIndex+1:]
// Check if the prefix matches our expected alias or table name (case-insensitive)
if strings.EqualFold(prefix, expectedAlias) ||
strings.EqualFold(prefix, tableName) ||
strings.EqualFold(prefix, strings.ToLower(tableName)) {
// Prefix matches current table, it's safe but redundant - leave it
continue
}
// Check if the prefix could plausibly be an alias/acronym for this table
// Only strip if we're confident it's meant for this table
// For example: "APIL" could be an acronym for "apiproviderlink"
prefixLower := strings.ToLower(prefix)
tableNameLower := strings.ToLower(tableName)
// Check if prefix is a substring of table name
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
// Check if prefix is an acronym of table name
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
isAcronym := false
if !isSubstring && len(prefixLower) > 2 {
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
}
if isSubstring || isAcronym {
// This looks like it could be an alias for this table - strip it
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
// Replace the qualified reference with just the column name
modified = strings.ReplaceAll(modified, part, column)
} else {
// Prefix doesn't match the current table at all
// It's likely referring to a different table (JOIN/preload)
// DON'T strip it - leave the qualified reference as-is
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
}
}
}
return modified
}
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.WhereOr(query, args...)
return b
@ -217,8 +486,122 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b
}
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
model := b.query.GetModel()
if model != nil && model.Value() != nil {
relType := reflection.GetRelationType(model.Value(), relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
if err != nil {
return
}
}
}()
if len(apply) == 0 {
return sq
}
@ -229,6 +612,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
db: b.db,
}
// Try to extract table name and alias from the preload model
if model := sq.GetModel(); model != nil && model.Value() != nil {
modelValue := model.Value()
// Extract table name if model implements TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
}
// Extract table alias if model implements TableAliasProvider
if provider, ok := modelValue.(common.TableAliasProvider); ok {
wrapper.tableAlias = provider.TableAlias()
// Apply the alias to the Bun query so conditions can reference it
if wrapper.tableAlias != "" {
// Note: Bun's Relation() already sets up the table, but we can add
// the alias explicitly if needed
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
}
}
}
// Start with the interface value (not pointer)
current := common.SelectQuery(wrapper)
@ -251,6 +656,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b
}
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
// Wrap the apply functions to automatically add table prefix to WHERE conditions
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
for _, fn := range apply {
if fn != nil {
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
return func(q common.SelectQuery) common.SelectQuery {
// Create a special wrapper that adds prefixes to WHERE conditions
if bunQuery, ok := q.(*BunSelectQuery); ok {
// Mark this query as being in JOIN context
bunQuery.inJoinContext = true
bunQuery.joinTableAlias = strings.ToLower(relation)
}
return originalFn(q)
}
}(fn)
wrappedApply = append(wrappedApply, wrappedFn)
}
}
// Use PreloadRelation with the wrapped functions
// Bun's Relation() will use JOIN for belongs-to and has-one relations
return b.PreloadRelation(relation, wrappedApply...)
}
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
b.query = b.query.Order(order)
return b
@ -276,33 +711,260 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
return b
}
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
return b.query.Scan(ctx, dest)
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Scan", r)
}
}()
if dest == nil {
return fmt.Errorf("destination cannot be nil")
}
// Execute the main query first
err = b.query.Scan(ctx, dest)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
func (b *BunSelectQuery) ScanModel(ctx context.Context) error {
return b.query.Scan(ctx)
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
// Enhanced panic recovery with model information
model := b.query.GetModel()
var modelInfo string
if model != nil && model.Value() != nil {
modelValue := model.Value()
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
// Try to get the model's underlying struct type
v := reflect.ValueOf(modelValue)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() == reflect.Slice {
if v.Type().Elem().Kind() == reflect.Ptr {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Elem().Name())
} else {
modelInfo += fmt.Sprintf(", Slice of: %s", v.Type().Elem().Name())
}
} else if v.Kind() == reflect.Struct {
modelInfo += fmt.Sprintf(", Struct: %s", v.Type().Name())
}
}
sqlStr := b.query.String()
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
}
}()
if b.query.GetModel() == nil {
return fmt.Errorf("model is nil")
}
// Optional: Enable detailed field-level debugging (set to true to debug)
const enableDetailedDebug = true
if enableDetailedDebug {
model := b.query.GetModel()
if model != nil && model.Value() != nil {
if err := debugScanIntoStruct(nil, model.Value()); err != nil {
logger.Warn("Debug scan inspection failed: %v", err)
}
}
}
// Execute the main query first
err = b.query.Scan(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
model := b.query.GetModel()
err = b.executeDeferredPreloads(ctx, model.Value())
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
}
}
return nil
}
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get the interface value to pass to Bun
parentValue := parentField.Interface()
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
return b.db.NewSelect().
Model(parentValue).
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
}
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0
}
}()
// If Model() was set, use bun's native Count() which works properly
if b.hasModel {
count, err := b.query.Count(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
}
// Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model
var count int
err := b.db.NewSelect().
countQuery := b.db.NewSelect().
TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)").
Scan(ctx, &count)
ColumnExpr("COUNT(*)")
err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err
}
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
return b.query.Exists(ctx)
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false
}
}()
exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return exists, err
}
// BunInsertQuery implements InsertQuery for Bun
@ -320,7 +982,6 @@ func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
if b.hasModel {
// If model is set, do not override table name
return b
}
b.query = b.query.Table(table)
@ -347,8 +1008,13 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
return b
}
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
if b.values != nil && len(b.values) > 0 {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunInsertQuery.Exec", r)
}
}()
if len(b.values) > 0 {
if !b.hasModel {
// If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly
@ -428,8 +1094,18 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return b
}
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err
}
@ -453,8 +1129,18 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
return b
}
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
}
}()
result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err
}
@ -526,3 +1212,7 @@ func (b *BunTxAdapter) RollbackTx(ctx context.Context) error {
func (b *BunTxAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
return fn(b) // Already in transaction
}
func (b *BunTxAdapter) GetUnderlyingDB() interface{} {
return b.tx
}

View File

@ -8,6 +8,7 @@ import (
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
@ -22,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db}
}
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.db = g.db.Debug()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
// DisableQueryDebug disables query debugging
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
// GORM's Debug() creates a new session, so we need to get the base DB
// This is a simplified implementation
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
return g
}
func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db}
}
@ -38,12 +55,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db}
}
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...)
return &GormResult{result: result}, result.Error
}
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
}
@ -63,19 +90,30 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
return g.db.WithContext(ctx).Rollback().Error
}
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx}
return fn(adapter)
})
}
func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
}
// GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct {
db *gorm.DB
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
db *gorm.DB
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
}
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
@ -109,15 +147,71 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
}
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Select(query, args...)
if len(args) > 0 {
g.db = g.db.Select(query, args...)
} else {
g.db = g.db.Select(query)
}
return g
}
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if g.inJoinContext && g.joinTableAlias != "" {
query = addTablePrefixGorm(query, g.joinTableAlias)
}
g.db = g.db.Where(query, args...)
return g
}
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
func addTablePrefixGorm(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeywordGorm(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
func isOperatorOrKeywordGorm(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Or(query, args...)
return g
@ -201,6 +295,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
}
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from GORM's statement if available
if g.db.Statement != nil && g.db.Statement.Model != nil {
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return g.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Use GORM's Preload (separate query strategy)
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
@ -230,6 +345,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
return g
}
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a JOIN instead of a separate preload query
// This is more efficient for many-to-one or one-to-one relationships
// as it avoids additional round trips to the database
// GORM's Joins() method forces a JOIN for the preload
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalGorm, ok := current.(*GormSelectQuery); ok {
return finalGorm.db
}
return db
})
return g
}
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
g.db = g.db.Order(order)
return g
@ -255,26 +406,78 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
return g
}
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
return g.db.WithContext(ctx).Find(dest).Error
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
err = g.db.WithContext(ctx).Find(dest).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(dest)
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
}
func (g *GormSelectQuery) ScanModel(ctx context.Context) error {
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
}
}()
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(g.db.Statement.Model)
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
}
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
var count int64
err := g.db.WithContext(ctx).Count(&count).Error
return int(count), err
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Count(&count64)
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return int(count64), err
}
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false
}
}()
var count int64
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Limit(1).Count(&count)
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return count > 0, err
}
@ -314,7 +517,12 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
return g
}
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB
switch {
case g.model != nil:
@ -401,8 +609,20 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return g
}
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Updates(g.updates)
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error
}
@ -428,8 +648,20 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
return g
}
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Delete(g.model)
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,176 @@
package database
import (
"context"
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example demonstrates how to use the PgSQL adapter
func ExamplePgSQLAdapter() error {
// Connect to PostgreSQL database
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer db.Close()
// Create the PgSQL adapter
adapter := NewPgSQLAdapter(db)
// Enable query debugging (optional)
adapter.EnableQueryDebug()
ctx := context.Background()
// Example 1: Simple SELECT query
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Where("age > ?", 18).
Order("created_at DESC").
Limit(10).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("select failed: %w", err)
}
// Example 2: INSERT query
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Returning("id").
Exec(ctx)
if err != nil {
return fmt.Errorf("insert failed: %w", err)
}
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
// Example 3: UPDATE query
result, err = adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
if err != nil {
return fmt.Errorf("update failed: %w", err)
}
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
// Example 4: DELETE query
result, err = adapter.NewDelete().
Table("users").
Where("age < ?", 18).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete failed: %w", err)
}
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
// Example 5: Using transactions
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
// Insert a new user
_, err := tx.NewInsert().
Table("users").
Value("name", "Transaction User").
Value("email", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Update another user
_, err = tx.NewUpdate().
Table("users").
Set("verified", true).
Where("email = ?", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Both operations succeed or both rollback
return nil
})
if err != nil {
return fmt.Errorf("transaction failed: %w", err)
}
// Example 6: JOIN query
err = adapter.NewSelect().
Table("users u").
Column("u.id", "u.name", "p.title as post_title").
LeftJoin("posts p ON p.user_id = u.id").
Where("u.active = ?", true).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("join query failed: %w", err)
}
// Example 7: Aggregation query
count, err := adapter.NewSelect().
Table("users").
Where("active = ?", true).
Count(ctx)
if err != nil {
return fmt.Errorf("count failed: %w", err)
}
fmt.Printf("Active users: %d\n", count)
// Example 8: Raw SQL execution
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
if err != nil {
return fmt.Errorf("raw exec failed: %w", err)
}
// Example 9: Raw SQL query
var users []map[string]interface{}
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
if err != nil {
return fmt.Errorf("raw query failed: %w", err)
}
return nil
}
// User is an example model
type User struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Age int `json:"age"`
}
// TableName implements common.TableNameProvider
func (u User) TableName() string {
return "users"
}
// ExampleWithModel demonstrates using models with the PgSQL adapter
func ExampleWithModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Use model with adapter
user := User{}
err = adapter.NewSelect().
Model(&user).
Where("id = ?", 1).
Scan(ctx, &user)
return err
}

View File

@ -0,0 +1,526 @@
// +build integration
package database
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// Integration test models
type IntegrationUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
}
func (u IntegrationUser) TableName() string {
return "users"
}
type IntegrationPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
Published bool `db:"published"`
CreatedAt time.Time `db:"created_at"`
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
}
func (p IntegrationPost) TableName() string {
return "posts"
}
type IntegrationComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
CreatedAt time.Time `db:"created_at"`
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
}
func (c IntegrationComment) TableName() string {
return "comments"
}
// setupTestDB creates a PostgreSQL container and returns the connection
func setupTestDB(t *testing.T) (*sql.DB, func()) {
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: "postgres:15-alpine",
ExposedPorts: []string{"5432/tcp"},
Env: map[string]string{
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpass",
"POSTGRES_DB": "testdb",
},
WaitingFor: wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(60 * time.Second),
}
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
require.NoError(t, err)
host, err := postgres.Host(ctx)
require.NoError(t, err)
port, err := postgres.MappedPort(ctx, "5432")
require.NoError(t, err)
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
host, port.Port())
db, err := sql.Open("pgx", dsn)
require.NoError(t, err)
// Wait for database to be ready
err = db.Ping()
require.NoError(t, err)
// Create schema
createSchema(t, db)
cleanup := func() {
db.Close()
postgres.Terminate(ctx)
}
return db, cleanup
}
// createSchema creates test tables
func createSchema(t *testing.T, db *sql.DB) {
schema := `
DROP TABLE IF EXISTS comments CASCADE;
DROP TABLE IF EXISTS posts CASCADE;
DROP TABLE IF EXISTS users CASCADE;
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
age INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE posts (
id SERIAL PRIMARY KEY,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
published BOOLEAN DEFAULT false,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE comments (
id SERIAL PRIMARY KEY,
content TEXT NOT NULL,
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(schema)
require.NoError(t, err)
}
// TestIntegration_BasicCRUD tests basic CRUD operations
func TestIntegration_BasicCRUD(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// CREATE
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// READ
var users []IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, 25, users[0].Age)
userID := users[0].ID
// UPDATE
result, err = adapter.NewUpdate().
Table("users").
Set("age", 26).
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify update
var updatedUser IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Scan(ctx, &updatedUser)
require.NoError(t, err)
assert.Equal(t, 26, updatedUser.Age)
// DELETE
result, err = adapter.NewDelete().
Table("users").
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify delete
count, err := adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 0, count)
}
// TestIntegration_ScanModel tests ScanModel functionality
func TestIntegration_ScanModel(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Insert test data
_, err := adapter.NewInsert().
Table("users").
Value("name", "Jane Smith").
Value("email", "jane@example.com").
Value("age", 30).
Exec(ctx)
require.NoError(t, err)
// Test single struct scan
user := &IntegrationUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("email = ?", "jane@example.com").
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, "Jane Smith", user.Name)
assert.Equal(t, 30, user.Age)
// Test slice scan
users := []*IntegrationUser{}
err = adapter.NewSelect().
Model(&users).
Table("users").
ScanModel(ctx)
require.NoError(t, err)
assert.Len(t, users, 1)
}
// TestIntegration_Transaction tests transaction handling
func TestIntegration_Transaction(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Successful transaction
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Alice").
Value("email", "alice@example.com").
Value("age", 28).
Exec(ctx)
if err != nil {
return err
}
_, err = tx.NewInsert().
Table("users").
Value("name", "Bob").
Value("email", "bob@example.com").
Value("age", 32).
Exec(ctx)
return err
})
require.NoError(t, err)
// Verify both records exist
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Failed transaction (should rollback)
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Charlie").
Value("email", "charlie@example.com").
Value("age", 35).
Exec(ctx)
if err != nil {
return err
}
// Intentional error - duplicate email
_, err = tx.NewInsert().
Table("users").
Value("name", "David").
Value("email", "alice@example.com"). // Duplicate
Value("age", 40).
Exec(ctx)
return err
})
assert.Error(t, err)
// Verify rollback - count should still be 2
count, err = adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
}
// TestIntegration_Preload tests basic preload functionality
func TestIntegration_Preload(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
// Test Preload
var users []*IntegrationUser
err := adapter.NewSelect().
Model(&IntegrationUser{}).
Table("users").
Preload("Posts").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0].Posts)
assert.Len(t, users[0].Posts, 2)
}
// TestIntegration_PreloadRelation tests smart PreloadRelation
func TestIntegration_PreloadRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
createTestComment(t, adapter, ctx, postID, "Great post!")
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
// Test PreloadRelation with belongs-to (should use JOIN)
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
// Note: JOIN preloading needs proper column selection to work
// For now, we test that it doesn't error
// Test PreloadRelation with has-many (should use subquery)
posts = []*IntegrationPost{}
err = adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("Comments").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
if posts[0].Comments != nil {
assert.Len(t, posts[0].Comments, 2)
}
}
// TestIntegration_JoinRelation tests explicit JoinRelation
func TestIntegration_JoinRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
// Test JoinRelation
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
JoinRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
}
// TestIntegration_ComplexQuery tests complex queries
func TestIntegration_ComplexQuery(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
// Complex query with joins, where, order, limit
var results []map[string]interface{}
err := adapter.NewSelect().
Table("posts p").
Column("p.title", "u.name as author_name", "u.age as author_age").
LeftJoin("users u ON u.id = p.user_id").
Where("p.published = ?", true).
WhereOr("u.age > ?", 25).
Order("u.age DESC").
Limit(2).
Scan(ctx, &results)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 2)
}
// TestIntegration_Aggregation tests aggregation queries
func TestIntegration_Aggregation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
// Test Count
count, err := adapter.NewSelect().
Table("users").
Where("age >= ?", 25).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Test Exists
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "user1@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
// Test Group By with aggregation
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Column("age", "COUNT(*) as count").
Group("age").
Having("COUNT(*) > ?", 0).
Order("age ASC").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 3)
}
// Helper functions
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
var userID int
err := adapter.Query(ctx, &userID,
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
name, email, age)
require.NoError(t, err)
return userID
}
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
var postID int
err := adapter.Query(ctx, &postID,
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
title, content, userID, published)
require.NoError(t, err)
return postID
}
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
var commentID int
err := adapter.Query(ctx, &commentID,
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
content, postID)
require.NoError(t, err)
return commentID
}

View File

@ -0,0 +1,275 @@
package database
import (
"context"
"database/sql"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example models for demonstrating preload functionality
// Author model - has many Posts
type Author struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Posts []*Post `bun:"rel:has-many,join:id=author_id"`
}
func (a Author) TableName() string {
return "authors"
}
// Post model - belongs to Author, has many Comments
type Post struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
AuthorID int `db:"author_id"`
Author *Author `bun:"rel:belongs-to,join:author_id=id"`
Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
}
func (p Post) TableName() string {
return "posts"
}
// Comment model - belongs to Post
type Comment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
Post *Post `bun:"rel:belongs-to,join:post_id=id"`
}
func (c Comment) TableName() string {
return "comments"
}
// ExamplePreload demonstrates the Preload functionality
func ExamplePreload() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Simple Preload (uses subquery for has-many)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
Preload("Posts"). // Load all posts for each author
Scan(ctx, &authors)
if err != nil {
return err
}
// Now authors[i].Posts will be populated with their posts
return nil
}
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection
func ExamplePreloadRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: PreloadRelation auto-detects has-many (uses subquery)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Where("published = ?", true).Order("created_at DESC")
}).
Where("active = ?", true).
Scan(ctx, &authors)
if err != nil {
return err
}
// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
PreloadRelation("Author"). // Will use JOIN because it's belongs-to
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 3: Nested preloads
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
// First load posts, then preload comments for each post
return q.Limit(10)
}).
Scan(ctx, &authors)
if err != nil {
return err
}
// Manually load nested relationships (two-level preloading)
for _, author := range authors {
if author.Posts != nil {
for _, post := range author.Posts {
var comments []*Comment
err := adapter.NewSelect().
Table("comments").
Where("post_id = ?", post.ID).
Scan(ctx, &comments)
if err == nil {
post.Comments = comments
}
}
}
}
return nil
}
// ExampleJoinRelation demonstrates explicit JOIN loading
func ExampleJoinRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Force JOIN for belongs-to relationship
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 2: Multiple JOINs
err = adapter.NewSelect().
Model(&Post{}).
Table("posts p").
Column("p.*", "a.name as author_name", "a.email as author_email").
LeftJoin("authors a ON a.id = p.author_id").
Where("p.published = ?", true).
Scan(ctx, &posts)
return err
}
// ExampleScanModel demonstrates ScanModel with struct destinations
func ExampleScanModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Scan single struct
author := Author{}
err = adapter.NewSelect().
Model(&author).
Table("authors").
Where("id = ?", 1).
ScanModel(ctx) // ScanModel automatically uses the model set with Model()
if err != nil {
return err
}
// Example 2: Scan slice of structs
authors := []*Author{}
err = adapter.NewSelect().
Model(&authors).
Table("authors").
Where("active = ?", true).
Limit(10).
ScanModel(ctx)
return err
}
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading
func ExampleCompleteWorkflow() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
adapter.EnableQueryDebug() // Enable query logging
ctx := context.Background()
// Step 1: Create an author
author := &Author{
Name: "John Doe",
Email: "john@example.com",
}
result, err := adapter.NewInsert().
Table("authors").
Value("name", author.Name).
Value("email", author.Email).
Returning("id").
Exec(ctx)
if err != nil {
return err
}
_ = result
// Step 2: Load author with all their posts
var loadedAuthor Author
err = adapter.NewSelect().
Model(&loadedAuthor).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Order("created_at DESC").Limit(5)
}).
Where("id = ?", 1).
ScanModel(ctx)
if err != nil {
return err
}
// Step 3: Update author name
_, err = adapter.NewUpdate().
Table("authors").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
return err
}

View File

@ -0,0 +1,629 @@
package database
import (
"context"
"database/sql"
"reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
}
func (u TestUser) TableName() string {
return "users"
}
type TestPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
User *TestUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
}
func (p TestPost) TableName() string {
return "posts"
}
type TestComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
}
func (c TestComment) TableName() string {
return "comments"
}
// TestNewPgSQLAdapter tests adapter creation
func TestNewPgSQLAdapter(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
assert.NotNil(t, adapter)
assert.Equal(t, db, adapter.db)
}
// TestPgSQLSelectQuery_BuildSQL tests SQL query building
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
tests := []struct {
name string
setup func(*PgSQLSelectQuery)
expected string
}{
{
name: "simple select",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
},
expected: "SELECT * FROM users",
},
{
name: "select with columns",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.columns = []string{"id", "name", "email"}
},
expected: "SELECT id, name, email FROM users",
},
{
name: "select with where",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.whereClauses = []string{"age > $1"}
q.args = []interface{}{18}
},
expected: "SELECT * FROM users WHERE (age > $1)",
},
{
name: "select with order and limit",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.orderBy = []string{"created_at DESC"}
q.limit = 10
q.offset = 5
},
expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
},
{
name: "select with join",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
},
expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
},
{
name: "select with group and having",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.groupBy = []string{"country"}
q.havingClauses = []string{"COUNT(*) > $1"}
q.args = []interface{}{5}
},
expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{
columns: []string{"*"},
}
tt.setup(q)
sql := q.buildSQL()
assert.Equal(t, tt.expected, sql)
})
}
}
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
tests := []struct {
name string
query string
argCount int
paramCounter int
expected string
}{
{
name: "single placeholder",
query: "age > ?",
argCount: 1,
paramCounter: 0,
expected: "age > $1",
},
{
name: "multiple placeholders",
query: "age > ? AND status = ?",
argCount: 2,
paramCounter: 0,
expected: "age > $1 AND status = $2",
},
{
name: "with existing counter",
query: "name = ?",
argCount: 1,
paramCounter: 5,
expected: "name = $6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
result := q.replacePlaceholders(tt.query, tt.argCount)
assert.Equal(t, tt.expected, result)
})
}
}
// TestPgSQLSelectQuery_Chaining tests method chaining
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
query := adapter.NewSelect().
Table("users").
Column("id", "name").
Where("age > ?", 18).
Order("name ASC").
Limit(10).
Offset(5)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
assert.Len(t, pgQuery.whereClauses, 1)
assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
assert.Equal(t, 10, pgQuery.limit)
assert.Equal(t, 5, pgQuery.offset)
}
// TestPgSQLSelectQuery_Model tests model setting
func TestPgSQLSelectQuery_Model(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
user := &TestUser{}
query := adapter.NewSelect().Model(user)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, user, pgQuery.model)
}
// TestScanRowsToStructSlice tests scanning rows into struct slice
func TestScanRowsToStructSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25).
AddRow(2, "Jane Smith", "jane@example.com", 30)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, "jane@example.com", users[1].Email)
assert.Equal(t, 30, users[1].Age)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice
func TestScanRowsToStructSlicePointers(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []*TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0])
assert.Equal(t, "John Doe", users[0].Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToSingleStruct tests scanning a single row
func TestScanRowsToSingleStruct(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var user TestUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", 1).
Scan(ctx, &user)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToMapSlice tests scanning into map slice
func TestScanRowsToMapSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email"}).
AddRow(1, "John Doe", "john@example.com").
AddRow(2, "Jane Smith", "jane@example.com")
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 2)
assert.Equal(t, int64(1), results[0]["id"])
assert.Equal(t, "John Doe", results[0]["name"])
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLInsertQuery_Exec tests insert query execution
func TestPgSQLInsertQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("John Doe", "john@example.com", 25).
WillReturnResult(sqlmock.NewResult(1, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLUpdateQuery_Exec tests update query execution
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Note: Args order is SET values first, then WHERE values
mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
WithArgs("Jane Doe", 1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLDeleteQuery_Exec tests delete query execution
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewDelete().
Table("users").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Count tests count query
func TestPgSQLSelectQuery_Count(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 42, count)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Exists tests exists query
func TestPgSQLSelectQuery_Exists(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_Transaction tests transaction handling
func TestPgSQLAdapter_Transaction(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_TransactionRollback tests transaction rollback
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
mock.ExpectRollback()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
assert.Error(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestBuildFieldMap tests field mapping construction
func TestBuildFieldMap(t *testing.T) {
userType := reflect.TypeOf(TestUser{})
fieldMap := buildFieldMap(userType, nil)
assert.NotEmpty(t, fieldMap)
// Check that fields are mapped
assert.Contains(t, fieldMap, "id")
assert.Contains(t, fieldMap, "name")
assert.Contains(t, fieldMap, "email")
assert.Contains(t, fieldMap, "age")
// Check field info
idInfo := fieldMap["id"]
assert.Equal(t, "ID", idInfo.Name)
}
// TestGetRelationMetadata tests relationship metadata extraction
func TestGetRelationMetadata(t *testing.T) {
q := &PgSQLSelectQuery{
model: &TestPost{},
}
// Test belongs-to relationship
meta := q.getRelationMetadata("User")
assert.NotNil(t, meta)
assert.Equal(t, "User", meta.fieldName)
// Test has-many relationship
meta = q.getRelationMetadata("Comments")
assert.NotNil(t, meta)
assert.Equal(t, "Comments", meta.fieldName)
}
// TestPreloadConfiguration tests preload configuration
func TestPreloadConfiguration(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
// Test Preload
query := adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
Preload("User")
pgQuery := query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.False(t, pgQuery.preloads[0].useJoin)
// Test PreloadRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
PreloadRelation("Comments")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "Comments", pgQuery.preloads[0].relation)
// Test JoinRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
JoinRelation("User")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.True(t, pgQuery.preloads[0].useJoin)
}
// TestScanModel tests ScanModel functionality
func TestScanModel(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
user := &TestUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("id = ?", 1).
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestRawSQL tests raw SQL execution
func TestRawSQL(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Test Exec
mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
require.NoError(t, err)
// Test Query
rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)
var results []map[string]interface{}
err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.NoError(t, mock.ExpectationsWereMet())
}

View File

@ -0,0 +1,132 @@
package database
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/require"
)
// TestHelper provides utilities for database testing
type TestHelper struct {
DB *sql.DB
Adapter *PgSQLAdapter
t *testing.T
}
// NewTestHelper creates a new test helper
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
return &TestHelper{
DB: db,
Adapter: NewPgSQLAdapter(db),
t: t,
}
}
// CleanupTables truncates all test tables
func (h *TestHelper) CleanupTables() {
ctx := context.Background()
tables := []string{"comments", "posts", "users"}
for _, table := range tables {
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
require.NoError(h.t, err)
}
}
// InsertUser inserts a test user and returns the ID
func (h *TestHelper) InsertUser(name, email string, age int) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("users").
Value("name", name).
Value("email", email).
Value("age", age).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertPost inserts a test post and returns the ID
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("posts").
Value("user_id", userID).
Value("title", title).
Value("content", content).
Value("published", published).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertComment inserts a test comment and returns the ID
func (h *TestHelper) InsertComment(postID int, content string) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("comments").
Value("post_id", postID).
Value("content", content).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// AssertUserExists checks if a user exists by email
func (h *TestHelper) AssertUserExists(email string) {
ctx := context.Background()
exists, err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Exists(ctx)
require.NoError(h.t, err)
require.True(h.t, exists, "User with email %s should exist", email)
}
// AssertUserCount asserts the number of users
func (h *TestHelper) AssertUserCount(expected int) {
ctx := context.Background()
count, err := h.Adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(h.t, err)
require.Equal(h.t, expected, count)
}
// GetUserByEmail retrieves a user by email
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
ctx := context.Background()
var results []map[string]interface{}
err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Scan(ctx, &results)
require.NoError(h.t, err)
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
return results[0]
}
// BeginTestTransaction starts a transaction for testing
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
ctx := context.Background()
tx, err := h.DB.BeginTx(ctx, nil)
require.NoError(h.t, err)
adapter := &PgSQLTxAdapter{tx: tx}
cleanup := func() {
tx.Rollback()
}
return adapter, cleanup
}

View File

@ -6,6 +6,7 @@ import (
"github.com/uptrace/bunrouter"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
@ -35,7 +36,11 @@ func (b *BunRouterAdapter) HandleFunc(pattern string, handler common.HTTPHandler
func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetBunRouter() for direct access")
w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
}
// GetBunRouter returns the underlying bunrouter for direct access
@ -121,6 +126,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
return b.req.URL.Query().Get(key)
}
func (b *BunRouterRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range b.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (b *BunRouterRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range b.req.Header {
@ -131,6 +146,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
return headers
}
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
return b.req.Request
}
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
type StandardBunRouterAdapter struct {
*BunRouterAdapter

View File

@ -8,6 +8,7 @@ import (
"github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// MuxAdapter adapts Gorilla Mux to work with our Router interface
@ -32,7 +33,11 @@ func (m *MuxAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc)
func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetMuxRouter() for direct access")
w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
}
// MuxRouteRegistration implements RouteRegistration for Mux
@ -117,6 +122,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
return h.req.URL.Query().Get(key)
}
func (h *HTTPRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range h.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (h *HTTPRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range h.req.Header {
@ -127,6 +142,12 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
return headers
}
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
return h.req
}
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
type HTTPResponseWriter struct {
resp http.ResponseWriter
@ -156,6 +177,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
return json.NewEncoder(h.resp).Encode(data)
}
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
// This is useful when you need to pass the response writer to other handlers
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return h.resp
}
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
type StandardMuxAdapter struct {
*MuxAdapter

119
pkg/common/cors.go Normal file
View File

@ -0,0 +1,119 @@
package common
import (
"fmt"
"strings"
)
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
MaxAge int
}
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowedHeaders: GetHeadSpecHeaders(),
MaxAge: 86400, // 24 hours
}
}
// GetHeadSpecHeaders returns all headers used by HeadSpec
func GetHeadSpecHeaders() []string {
return []string{
// Standard headers
"Content-Type",
"Authorization",
"Accept",
"Accept-Language",
"Content-Language",
// Field Selection
"X-Select-Fields",
"X-Not-Select-Fields",
"X-Clean-JSON",
// Filtering & Search
"X-FieldFilter-*",
"X-SearchFilter-*",
"X-SearchOp-*",
"X-SearchOr-*",
"X-SearchAnd-*",
"X-SearchCols",
"X-Custom-SQL-W",
"X-Custom-SQL-W-*",
"X-Custom-SQL-Or",
"X-Custom-SQL-Or-*",
// Joins & Relations
"X-Preload",
"X-Preload-*",
"X-Expand",
"X-Expand-*",
"X-Custom-SQL-Join",
"X-Custom-SQL-Join-*",
// Sorting & Pagination
"X-Sort",
"X-Sort-*",
"X-Limit",
"X-Offset",
"X-Cursor-Forward",
"X-Cursor-Backward",
// Advanced Features
"X-AdvSQL-*",
"X-CQL-Sel-*",
"X-Distinct",
"X-SkipCount",
"X-SkipCache",
"X-Fetch-RowNumber",
"X-PKRow",
// Response Format
"X-SimpleAPI",
"X-DetailAPI",
"X-Syncfusion",
"X-Single-Record-As-Object",
// Transaction Control
"X-Transaction-Atomic",
// X-Files - comprehensive JSON configuration
"X-Files",
}
}
// SetCORSHeaders sets CORS headers on a response writer
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
// Set allowed origins
if len(config.AllowedOrigins) > 0 {
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
}
// Set allowed methods
if len(config.AllowedMethods) > 0 {
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
}
// Set allowed headers
if len(config.AllowedHeaders) > 0 {
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
}
// Set max age
if config.MaxAge > 0 {
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
}
// Allow credentials
w.SetHeader("Access-Control-Allow-Credentials", "true")
// Expose headers that clients can read
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
}

View File

@ -0,0 +1,97 @@
package common
// Example showing how to use the common handler interfaces
// This file demonstrates the handler interface hierarchy and usage patterns
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
// which works with any handler type (resolvespec, restheadspec, or funcspec)
func ProcessWithAnyHandler(handler SpecHandler) Database {
// All handlers expose GetDatabase() through the SpecHandler interface
return handler.GetDatabase()
}
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
// which works with resolvespec.Handler and restheadspec.Handler
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement Handle()
handler.Handle(w, r, params)
}
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement HandleGet()
handler.HandleGet(w, r, params)
}
// Example usage patterns (not executable, just for documentation):
/*
// Example 1: Using with resolvespec.Handler
func ExampleResolveSpec() {
db := // ... get database
registry := // ... get registry
handler := resolvespec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 2: Using with restheadspec.Handler
func ExampleRestHeadSpec() {
db := // ... get database
registry := // ... get registry
handler := restheadspec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 3: Using with funcspec.Handler
func ExampleFuncSpec() {
db := // ... get database
handler := funcspec.NewHandler(db)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as QueryHandler
var queryHandler QueryHandler = handler
// funcspec has different methods: SqlQueryList() and SqlQuery()
// which return HTTP handler functions
}
// Example 4: Polymorphic handler processing
func ProcessHandlers(handlers []SpecHandler) {
for _, handler := range handlers {
// All handlers expose the database
db := handler.GetDatabase()
// Type switch for specific handler types
switch h := handler.(type) {
case CRUDHandler:
// This is resolvespec or restheadspec
// Can call Handle() and HandleGet()
_ = h
case QueryHandler:
// This is funcspec
// Can call SqlQueryList() and SqlQuery()
_ = h
}
}
}
*/

View File

@ -0,0 +1,47 @@
package common
import (
"fmt"
"reflect"
)
// ValidateAndUnwrapModelResult contains the result of model validation
type ValidateAndUnwrapModelResult struct {
ModelType reflect.Type
Model interface{}
ModelPtr interface{}
OriginalType reflect.Type
}
// ValidateAndUnwrapModel validates that a model is a struct type and unwraps
// pointers, slices, and arrays to get to the base struct type.
// Returns an error if the model is not a valid struct type.
func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, error) {
modelType := reflect.TypeOf(model)
originalType := modelType
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, fmt.Errorf("model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType)
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
return &ValidateAndUnwrapModelResult{
ModelType: modelType,
Model: model,
ModelPtr: modelPtr,
OriginalType: originalType,
}, nil
}

View File

@ -1,6 +1,11 @@
package common
import "context"
import (
"context"
"encoding/json"
"io"
"net/http"
)
// Database interface designed to work with both GORM and Bun
type Database interface {
@ -19,6 +24,12 @@ type Database interface {
CommitTx(ctx context.Context) error
RollbackTx(ctx context.Context) error
RunInTransaction(ctx context.Context, fn func(Database) error) error
// GetUnderlyingDB returns the underlying database connection
// For GORM, this returns *gorm.DB
// For Bun, this returns *bun.DB
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
GetUnderlyingDB() interface{}
}
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
@ -33,6 +44,7 @@ type SelectQuery interface {
LeftJoin(query string, args ...interface{}) SelectQuery
Preload(relation string, conditions ...interface{}) SelectQuery
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery
Limit(n int) SelectQuery
Offset(n int) SelectQuery
@ -116,6 +128,8 @@ type Request interface {
Body() ([]byte, error)
PathParam(key string) string
QueryParam(key string) string
AllQueryParams() map[string]string // Get all query parameters as a map
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
}
// ResponseWriter interface abstracts HTTP response
@ -124,11 +138,113 @@ type ResponseWriter interface {
WriteHeader(statusCode int)
Write(data []byte) (int, error)
WriteJSON(data interface{}) error
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
}
// HTTPHandlerFunc type for HTTP handlers
type HTTPHandlerFunc func(ResponseWriter, Request)
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
}
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
type StandardResponseWriter struct {
w http.ResponseWriter
status int
}
func (s *StandardResponseWriter) SetHeader(key, value string) {
s.w.Header().Set(key, value)
}
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
s.status = statusCode
s.w.WriteHeader(statusCode)
}
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
return s.w.Write(data)
}
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data)
}
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return s.w
}
// StandardRequest adapts *http.Request to Request interface
type StandardRequest struct {
r *http.Request
body []byte
}
func (s *StandardRequest) Method() string {
return s.r.Method
}
func (s *StandardRequest) URL() string {
return s.r.URL.String()
}
func (s *StandardRequest) Header(key string) string {
return s.r.Header.Get(key)
}
func (s *StandardRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range s.r.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
return headers
}
func (s *StandardRequest) Body() ([]byte, error) {
if s.body != nil {
return s.body, nil
}
if s.r.Body == nil {
return nil, nil
}
defer s.r.Body.Close()
body, err := io.ReadAll(s.r.Body)
if err != nil {
return nil, err
}
s.body = body
return body, nil
}
func (s *StandardRequest) PathParam(key string) string {
// Standard http.Request doesn't have path params
// This should be set by the router
return ""
}
func (s *StandardRequest) QueryParam(key string) string {
return s.r.URL.Query().Get(key)
}
func (s *StandardRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range s.r.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (s *StandardRequest) UnderlyingRequest() *http.Request {
return s.r
}
// TableNameProvider interface for models that provide table names
type TableNameProvider interface {
TableName() string
@ -147,3 +263,39 @@ type PrimaryKeyNameProvider interface {
type SchemaProvider interface {
SchemaName() string
}
// SpecHandler interface represents common functionality across all spec handlers
// This is the base interface implemented by:
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
//
// The interface hierarchy is:
//
// SpecHandler (base)
// ├── CRUDHandler (resolvespec, restheadspec)
// └── QueryHandler (funcspec)
type SpecHandler interface {
// GetDatabase returns the underlying database connection
GetDatabase() Database
}
// CRUDHandler interface for handlers that support CRUD operations
// This is implemented by resolvespec.Handler and restheadspec.Handler
type CRUDHandler interface {
SpecHandler
// Handle processes API requests through router-agnostic interface
Handle(w ResponseWriter, r Request, params map[string]string)
// HandleGet processes GET requests for metadata
HandleGet(w ResponseWriter, r Request, params map[string]string)
}
// QueryHandler interface for handlers that execute SQL queries
// This is implemented by funcspec.Handler
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
type QueryHandler interface {
SpecHandler
// Methods are defined in funcspec package due to different function signature requirements
}

635
pkg/common/sql_helpers.go Normal file
View File

@ -0,0 +1,635 @@
package common
import (
"fmt"
"regexp"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
//
// NOTE: For preload queries, table aliases from the parent query are not valid since
// the preload executes as a separate query with its own table alias. This function
// now simply validates basic syntax without requiring or adding prefixes.
// The actual alias normalization happens in the database adapter layer.
//
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
where = strings.TrimSpace(where)
// Just do basic validation - don't require or add prefixes
// The database adapter will handle alias normalization
// Check if the WHERE clause contains any qualified column references
// If it does, log a debug message but don't fail - let the adapter handle it
if strings.Contains(where, ".") {
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
"Note: In preload context, table aliases from parent query are not available. "+
"The database adapter will normalize aliases automatically.", relationName, where)
}
// Validate that it's not empty or just whitespace
if where == "" {
return where, nil
}
// Return the WHERE clause as-is
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
return where, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
// Returns an error if any dangerous keywords are found
func validateWhereClauseSecurity(where string) error {
if where == "" {
return nil
}
lowerWhere := strings.ToLower(where)
// List of dangerous SQL keywords that should never appear in WHERE clauses
dangerousKeywords := []string{
"delete ", "delete\t", "delete\n", "delete;",
"update ", "update\t", "update\n", "update;",
"truncate ", "truncate\t", "truncate\n", "truncate;",
"drop ", "drop\t", "drop\n", "drop;",
"alter ", "alter\t", "alter\n", "alter;",
"create ", "create\t", "create\n", "create;",
"insert ", "insert\t", "insert\n", "insert;",
"grant ", "grant\t", "grant\n", "grant;",
"revoke ", "revoke\t", "revoke\n", "revoke;",
"exec ", "exec\t", "exec\n", "exec;",
"execute ", "execute\t", "execute\n", "execute;",
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(lowerWhere, keyword) {
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
}
}
return nil
}
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
// - An empty string if all conditions were trivial or the input was empty
//
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
// prefix matches a preloaded relation name, in which case it's left unchanged.
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Validate that the WHERE clause doesn't contain dangerous SQL statements
if err := validateWhereClauseSecurity(where); err != nil {
logger.Debug("Security validation failed for WHERE clause: %v", err)
return ""
}
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Build a set of allowed table prefixes (main table + preloaded relations)
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
}
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition HAS a table prefix, check if it's correct
if tableName != "" && hasTablePrefix(condToCheck) {
// Extract the current prefix and column name
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name
oldRef := currentPrefix + "." + columnName
newRef := tableName + "." + columnName
cond = strings.Replace(cond, oldRef, newRef, 1)
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
} else {
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
}
}
}
} else if tableName != "" && !hasTablePrefix(condToCheck) {
// If tableName is provided and the condition DOESN'T have a table prefix,
// qualify unambiguous column references to prevent "ambiguous column" errors
// when there are multiple joins on the same table (e.g., recursive preloads)
columnName := extractUnqualifiedColumnName(condToCheck)
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
// Qualify the column with the table name
// Be careful to only replace the column name, not other occurrences of the string
oldRef := columnName
newRef := tableName + "." + columnName
// Use word boundary matching to avoid replacing partial matches
cond = qualifyColumnInCondition(cond, oldRef, newRef)
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string {
conditions := []string{}
currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth
i := 0
for i < len(where) {
ch := where[i]
// Track parenthesis depth
if ch == '(' {
depth++
currentCondition.WriteByte(ch)
i++
continue
} else if ch == ')' {
depth--
currentCondition.WriteByte(ch)
i++
continue
}
// Only look for AND operators at depth 0 (not inside parentheses)
if depth == 0 {
// Check if we're at an AND operator (case-insensitive)
// We need at least " AND " (5 chars) or " and " (5 chars)
if i+5 <= len(where) {
substring := where[i : i+5]
lowerSubstring := strings.ToLower(substring)
if lowerSubstring == " and " {
// Found an AND operator at the top level
// Add the current condition to the list
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
// Skip past the AND operator
i += 5
continue
}
}
}
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch)
i++
}
// Add the last condition
if currentCondition.Len() > 0 {
conditions = append(conditions, currentCondition.String())
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
// For example: "users.status = 'active'" returns ("users", "status")
// Returns empty strings if no table prefix is found
// This function is parenthesis-aware and will only look for operators outside of subqueries
func extractTableAndColumn(cond string) (table string, column string) {
// Common SQL operators to find the column reference
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
var columnRef string
// Find the column reference (left side of the operator)
// We need to find the first operator that appears OUTSIDE of parentheses
minIdx := -1
for _, op := range operators {
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
}
// If no operator found, the whole condition might be the column reference
if columnRef == "" {
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return "", ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(")
if openParenIdx >= 0 {
// There's a function call - find the FIRST dot after the opening paren
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
if dotIdx > 0 {
dotIdx += openParenIdx // Adjust to absolute position
// Extract table name (between paren and dot)
// Find the last opening paren before this dot
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
table = columnRef[lastOpenParen+1 : dotIdx]
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
columnStart := dotIdx + 1
columnEnd := len(columnRef)
for i := columnStart; i < len(columnRef); i++ {
ch := columnRef[i]
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
columnEnd = i
break
}
}
column = columnRef[columnStart:columnEnd]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
}
// No function call - check if it contains a dot (qualified reference)
// Use LastIndex to handle schema.table.column properly
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
return "", ""
}
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
// "status = 'active'" returns "status"
func extractUnqualifiedColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
// Find the column reference (left side of the operator)
minIdx := -1
for _, op := range operators {
idx := strings.Index(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
var columnRef string
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
} else {
// No operator found, might be a single column reference
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Return empty if it contains a dot (already qualified) or function call
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
return ""
}
return columnRef
}
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
// Uses word boundaries to avoid partial matches
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
// returns "table.rid_item is null"
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
// Use word boundary matching with Go's supported regex syntax
// \b matches word boundaries
escapedOld := regexp.QuoteMeta(oldRef)
pattern := `\b` + escapedOld + `\b`
re, err := regexp.Compile(pattern)
if err != nil {
// If regex fails, fall back to simple string replacement
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
return strings.Replace(cond, oldRef, newRef, 1)
}
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
result := cond
matches := re.FindAllStringIndex(cond, -1)
// Process matches in reverse order to maintain correct indices
for i := len(matches) - 1; i >= 0; i-- {
match := matches[i]
start := match[0]
// Check if preceded by a dot (already qualified)
if start > 0 && cond[start-1] == '.' {
continue
}
// Replace this occurrence
result = result[:start] + newRef + result[match[1]:]
}
return result
}
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
// Returns the index of the operator, or -1 if not found or only found inside parentheses
func findOperatorOutsideParentheses(s string, operator string) int {
depth := 0
inSingleQuote := false
inDoubleQuote := false
for i := 0; i < len(s); i++ {
ch := s[i]
// Track quote state (operators inside quotes should be ignored)
if ch == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if ch == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
// Skip if we're inside quotes
if inSingleQuote || inDoubleQuote {
continue
}
// Track parenthesis depth
switch ch {
case '(':
depth++
case ')':
depth--
}
// Only look for the operator when we're outside parentheses (depth == 0)
if depth == 0 {
// Check if the operator starts at this position
if i+len(operator) <= len(s) {
if s[i:i+len(operator)] == operator {
return i
}
}
}
}
return -1
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@ -0,0 +1,641 @@
package common
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses - prefix added to prevent ambiguity",
where: "(status = 'active')",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions - prefix added",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition with correct table prefix - unchanged",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition with incorrect table prefix - fixed",
where: "wrong_table.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple conditions with incorrect prefix - fixed",
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "multiple valid conditions without prefix - prefixes added",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
{
name: "mixed correct and incorrect prefixes",
where: "users.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",
where: "status = 'active'; DELETE FROM users",
tableName: "users",
expected: "",
},
{
name: "dangerous UPDATE keyword - blocked",
where: "1=1; UPDATE users SET admin = true",
tableName: "users",
expected: "",
},
{
name: "dangerous TRUNCATE keyword - blocked",
where: "status = 'active' OR TRUNCATE TABLE users",
tableName: "users",
expected: "",
},
{
name: "dangerous DROP keyword - blocked",
where: "status = 'active'; DROP TABLE users",
tableName: "users",
expected: "",
},
{
name: "subquery with table alias should not be modified",
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
tableName: "apiprovider",
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
},
{
name: "complex subquery with AND and multiple operators",
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
tableName: "apiprovider",
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
{
name: "complex sub query",
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestExtractTableAndColumn(t *testing.T) {
tests := []struct {
name string
input string
expectedTable string
expectedCol string
}{
{
name: "qualified column with equals",
input: "users.status = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "qualified column with greater than",
input: "users.age > 18",
expectedTable: "users",
expectedCol: "age",
},
{
name: "qualified column with LIKE",
input: "users.name LIKE '%john%'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "qualified column with IN",
input: "users.status IN ('active', 'pending')",
expectedTable: "users",
expectedCol: "status",
},
{
name: "unqualified column",
input: "status = 'active'",
expectedTable: "",
expectedCol: "",
},
{
name: "qualified with backticks",
input: "`users`.`status` = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "schema.table.column reference",
input: "public.users.status = 'active'",
expectedTable: "public.users",
expectedCol: "status",
},
{
name: "empty string",
input: "",
expectedTable: "",
expectedCol: "",
},
{
name: "function call with table.column - ifblnk",
input: "ifblnk(users.status,0) in (1,2,3,4)",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function call with table.column - coalesce",
input: "coalesce(users.age, 0) = 25",
expectedTable: "users",
expectedCol: "age",
},
{
name: "nested function calls",
input: "upper(trim(users.name)) = 'JOHN'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "function with multiple args and table.column",
input: "substring(users.email, 1, 5) = 'admin'",
expectedTable: "users",
expectedCol: "email",
},
{
name: "cast function with table.column",
input: "cast(orders.total as decimal) > 100",
expectedTable: "orders",
expectedCol: "total",
},
{
name: "complex nested functions",
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function with multiple table.column refs (extracts first)",
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
expectedTable: "users",
expectedCol: "created_at",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table, col := extractTableAndColumn(tt.input)
if table != tt.expectedTable || col != tt.expectedCol {
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
tt.input, table, col, tt.expectedTable, tt.expectedCol)
}
})
}
}
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "preload relation prefix is preserved",
where: "Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "Department.name = 'Engineering'",
},
{
name: "multiple preload relations - all preserved",
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
{Relation: "Manager"},
},
},
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
},
{
name: "mix of main table and preload relation",
where: "users.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "incorrect prefix fixed when not a preload relation",
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "Function Call with correct table prefix - unchanged",
where: "ifblnk(users.status,0) in (1,2,3,4)",
tableName: "users",
options: nil,
expected: "ifblnk(users.status,0) in (1,2,3,4)",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: nil,
expected: "users.status = 'active'",
},
{
name: "empty preload list - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: &RequestOptions{Preload: []PreloadOption{}},
expected: "users.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
if tt.options != nil {
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
} else {
result = SanitizeWhereClause(tt.where, tt.tableName)
}
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSplitByAND(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "uppercase AND",
input: "status = 'active' AND age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "lowercase and",
input: "status = 'active' and age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "mixed case AND",
input: "status = 'active' AND age > 18 and name = 'John'",
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
},
{
name: "single condition",
input: "status = 'active'",
expected: []string{"status = 'active'"},
},
{
name: "multiple uppercase AND",
input: "a = 1 AND b = 2 AND c = 3",
expected: []string{"a = 1", "b = 2", "c = 3"},
},
{
name: "multiple case subquery",
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitByAND(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
return
}
for i := range result {
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
}
}
})
}
}
func TestValidateWhereClauseSecurity(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "safe WHERE clause",
input: "status = 'active' AND age > 18",
expectError: false,
},
{
name: "safe subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expectError: false,
},
{
name: "DELETE keyword",
input: "status = 'active'; DELETE FROM users",
expectError: true,
},
{
name: "UPDATE keyword",
input: "1=1; UPDATE users SET admin = true",
expectError: true,
},
{
name: "TRUNCATE keyword",
input: "status = 'active' OR TRUNCATE TABLE users",
expectError: true,
},
{
name: "DROP keyword",
input: "status = 'active'; DROP TABLE users",
expectError: true,
},
{
name: "INSERT keyword",
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
expectError: true,
},
{
name: "ALTER keyword",
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
expectError: true,
},
{
name: "CREATE keyword",
input: "1=1; CREATE TABLE malicious (id INT)",
expectError: true,
},
{
name: "empty clause",
input: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateWhereClauseSecurity(tt.input)
if tt.expectError && err == nil {
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
}
if !tt.expectError && err != nil {
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
}
})
}
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column without prefix - no prefix added",
where: "status = 'active'",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "multiple valid columns without prefix - no prefix added",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "status = 'active' AND user_id = 123",
},
{
name: "incorrect table prefix on valid column - fixed",
where: "wrong_table.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "incorrect prefix on invalid column - not fixed",
where: "wrong_table.invalid_column = 'value'",
tableName: "mastertask",
expected: "wrong_table.invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "parentheses with valid column - no prefix added",
where: "(status = 'active')",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "correct prefix - unchanged",
where: "mastertask.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple conditions with mixed prefixes",
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -9,18 +9,18 @@ import (
"github.com/google/uuid"
)
// TestSqlInt16 tests SqlInt16 type
func TestSqlInt16(t *testing.T) {
// TestNewSqlInt16 tests NewSqlInt16 type
func TestNewSqlInt16(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt16
}{
{"int", 42, SqlInt16(42)},
{"int32", int32(100), SqlInt16(100)},
{"int64", int64(200), SqlInt16(200)},
{"string", "123", SqlInt16(123)},
{"nil", nil, SqlInt16(0)},
{"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)},
{"nil", nil, Null(int16(0), false)},
}
for _, tt := range tests {
@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
}
}
func TestSqlInt16_Value(t *testing.T) {
func TestNewSqlInt16_Value(t *testing.T) {
tests := []struct {
name string
input SqlInt16
expected driver.Value
}{
{"zero", SqlInt16(0), nil},
{"positive", SqlInt16(42), int64(42)},
{"negative", SqlInt16(-10), int64(-10)},
{"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int16(-10)},
}
for _, tt := range tests {
@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
}
}
func TestSqlInt16_JSON(t *testing.T) {
n := SqlInt16(42)
func TestNewSqlInt16_JSON(t *testing.T) {
n := NewSqlInt16(42)
// Marshal
data, err := json.Marshal(n)
@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if n2 != 123 {
t.Errorf("expected 123, got %d", n2)
if n2.Int64() != 123 {
t.Errorf("expected 123, got %d", n2.Int64())
}
}
// TestSqlInt64 tests SqlInt64 type
func TestSqlInt64(t *testing.T) {
// TestNewSqlInt64 tests NewSqlInt64 type
func TestNewSqlInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt64
}{
{"int", 42, SqlInt64(42)},
{"int32", int32(100), SqlInt64(100)},
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
{"uint32", uint32(100), SqlInt64(100)},
{"uint64", uint64(200), SqlInt64(200)},
{"nil", nil, SqlInt64(0)},
{"int", 42, NewSqlInt64(42)},
{"int32", int32(100), NewSqlInt64(100)},
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
{"uint32", uint32(100), NewSqlInt64(100)},
{"uint64", uint64(200), NewSqlInt64(200)},
{"nil", nil, SqlInt64{}},
}
for _, tt := range tests {
@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
}
if tt.valid && n.Float64 != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
if tt.valid && n.Float64() != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
}
})
}
@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if ts.GetTime().IsZero() {
if ts.Time().IsZero() {
t.Error("expected non-zero time")
}
})
@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := SqlTimeStamp(now)
ts := NewSqlTimeStamp(now)
// Marshal
data, err := json.Marshal(ts)
@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if ts2.GetTime().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
if ts2.Time().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
}
// Test null
@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
}
func TestSqlDate_JSON(t *testing.T) {
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal
data, err := json.Marshal(date)
@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
}
if tt.valid && u.String != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String)
if tt.valid && u.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String())
}
})
}
@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
u := NewSqlUUID(testUUID)
val, err := u.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != testUUID.String() {
if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val)
}
@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true}
u := NewSqlUUID(testUUID)
// Marshal
data, err := json.Marshal(u)
@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if u2.String != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
if u2.String() != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
}
// Test null

View File

@ -32,15 +32,22 @@ type Parameter struct {
}
type PreloadOption struct {
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
Relation string `json:"relation"`
Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"`
Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
}
type FilterOption struct {

View File

@ -6,6 +6,7 @@ import (
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ColumnValidator validates column names against a model's fields
@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name)
}
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column)
sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {

View File

@ -2,6 +2,8 @@ package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestExtractSourceColumn(t *testing.T) {
@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input)
result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}

291
pkg/config/README.md Normal file
View File

@ -0,0 +1,291 @@
# ResolveSpec Configuration System
A centralized configuration system with support for multiple configuration sources: config files (YAML, TOML, JSON), environment variables, and programmatic configuration.
## Features
- **Multiple Config Sources**: Config files, environment variables, and code
- **Priority Order**: Environment variables > Config file > Defaults
- **Multiple Formats**: YAML, TOML, JSON supported
- **Type Safety**: Strongly-typed configuration structs
- **Sensible Defaults**: Works out of the box with reasonable defaults
## Quick Start
### Basic Usage
```go
import "github.com/heinhel/ResolveSpec/pkg/config"
// Create a new config manager
mgr := config.NewManager()
// Load configuration from file and environment
if err := mgr.Load(); err != nil {
log.Fatal(err)
}
// Get the complete configuration
cfg, err := mgr.GetConfig()
if err != nil {
log.Fatal(err)
}
// Use the configuration
fmt.Println("Server address:", cfg.Server.Addr)
```
### Custom Configuration Paths
```go
mgr := config.NewManagerWithOptions(
config.WithConfigFile("/path/to/config.yaml"),
config.WithEnvPrefix("MYAPP"),
)
```
## Configuration Sources
### 1. Config Files
Place a `config.yaml` file in one of these locations:
- Current directory (`.`)
- `./config/`
- `/etc/resolvespec/`
- `$HOME/.resolvespec/`
Example `config.yaml`:
```yaml
server:
addr: ":8080"
shutdown_timeout: 30s
tracing:
enabled: true
service_name: "my-service"
cache:
provider: "redis"
redis:
host: "localhost"
port: 6379
```
### 2. Environment Variables
All configuration can be set via environment variables with the `RESOLVESPEC_` prefix:
```bash
export RESOLVESPEC_SERVER_ADDR=":9090"
export RESOLVESPEC_TRACING_ENABLED=true
export RESOLVESPEC_CACHE_PROVIDER=redis
export RESOLVESPEC_CACHE_REDIS_HOST=localhost
```
Nested configuration uses underscores:
- `server.addr``RESOLVESPEC_SERVER_ADDR`
- `cache.redis.host``RESOLVESPEC_CACHE_REDIS_HOST`
### 3. Programmatic Configuration
```go
mgr := config.NewManager()
mgr.Set("server.addr", ":9090")
mgr.Set("tracing.enabled", true)
cfg, _ := mgr.GetConfig()
```
## Configuration Options
### Server Configuration
```yaml
server:
addr: ":8080" # Server address
shutdown_timeout: 30s # Graceful shutdown timeout
drain_timeout: 25s # Connection drain timeout
read_timeout: 10s # HTTP read timeout
write_timeout: 10s # HTTP write timeout
idle_timeout: 120s # HTTP idle timeout
```
### Tracing Configuration
```yaml
tracing:
enabled: false # Enable/disable tracing
service_name: "resolvespec" # Service name
service_version: "1.0.0" # Service version
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
```
### Cache Configuration
```yaml
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
```
### Logger Configuration
```yaml
logger:
dev: false # Development mode (human-readable output)
path: "" # Log file path (empty = stdout)
```
### Middleware Configuration
```yaml
middleware:
rate_limit_rps: 100.0 # Requests per second
rate_limit_burst: 200 # Burst size
max_request_size: 10485760 # Max request size in bytes (10MB)
```
### CORS Configuration
```yaml
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
```
### Database Configuration
```yaml
database:
url: "host=localhost user=postgres password=postgres dbname=mydb port=5432 sslmode=disable"
```
## Priority and Overrides
Configuration sources are applied in this order (highest priority first):
1. **Environment Variables** (highest priority)
2. **Config File**
3. **Defaults** (lowest priority)
This allows you to:
- Set defaults in code
- Override with a config file
- Override specific values with environment variables
## Examples
### Production Setup
```yaml
# config.yaml
server:
addr: ":8080"
tracing:
enabled: true
service_name: "myapi"
endpoint: "http://jaeger:4318/v1/traces"
cache:
provider: "redis"
redis:
host: "redis"
port: 6379
password: "${REDIS_PASSWORD}"
logger:
dev: false
path: "/var/log/myapi/app.log"
```
### Development Setup
```bash
# Use environment variables for development
export RESOLVESPEC_LOGGER_DEV=true
export RESOLVESPEC_TRACING_ENABLED=false
export RESOLVESPEC_CACHE_PROVIDER=memory
```
### Testing Setup
```go
// Override config for tests
mgr := config.NewManager()
mgr.Set("cache.provider", "memory")
mgr.Set("database.url", testDBURL)
cfg, _ := mgr.GetConfig()
```
## Best Practices
1. **Use config files for base configuration** - Define your standard settings
2. **Use environment variables for secrets** - Never commit passwords/tokens
3. **Use environment variables for deployment-specific values** - Different per environment
4. **Keep defaults sensible** - Application should work with minimal configuration
5. **Document your configuration** - Comment your config.yaml files
## Integration with ResolveSpec Components
The configuration system integrates seamlessly with ResolveSpec components:
```go
cfg, _ := config.NewManager().Load().GetConfig()
// Server
srv := server.NewGracefulServer(server.Config{
Addr: cfg.Server.Addr,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
// ... other fields
})
// Tracing
if cfg.Tracing.Enabled {
tracer := tracing.Init(tracing.Config{
ServiceName: cfg.Tracing.ServiceName,
ServiceVersion: cfg.Tracing.ServiceVersion,
Endpoint: cfg.Tracing.Endpoint,
})
defer tracer.Shutdown(context.Background())
}
// Cache
var cacheProvider cache.Provider
switch cfg.Cache.Provider {
case "redis":
cacheProvider = cache.NewRedisProvider(cfg.Cache.Redis.Host, cfg.Cache.Redis.Port, ...)
case "memcache":
cacheProvider = cache.NewMemcacheProvider(cfg.Cache.Memcache.Servers, ...)
default:
cacheProvider = cache.NewMemoryProvider()
}
// Logger
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
```

143
pkg/config/config.go Normal file
View File

@ -0,0 +1,143 @@
package config
import "time"
// Config represents the complete application configuration
type Config struct {
Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
}
// ServerConfig holds server-related configuration
type ServerConfig struct {
Addr string `mapstructure:"addr"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
}
// TracingConfig holds OpenTelemetry tracing configuration
type TracingConfig struct {
Enabled bool `mapstructure:"enabled"`
ServiceName string `mapstructure:"service_name"`
ServiceVersion string `mapstructure:"service_version"`
Endpoint string `mapstructure:"endpoint"`
}
// CacheConfig holds cache provider configuration
type CacheConfig struct {
Provider string `mapstructure:"provider"` // memory, redis, memcache
Redis RedisConfig `mapstructure:"redis"`
Memcache MemcacheConfig `mapstructure:"memcache"`
}
// RedisConfig holds Redis-specific configuration
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
// MemcacheConfig holds Memcache-specific configuration
type MemcacheConfig struct {
Servers []string `mapstructure:"servers"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
Timeout time.Duration `mapstructure:"timeout"`
}
// LoggerConfig holds logger configuration
type LoggerConfig struct {
Dev bool `mapstructure:"dev"`
Path string `mapstructure:"path"`
}
// MiddlewareConfig holds middleware configuration
type MiddlewareConfig struct {
RateLimitRPS float64 `mapstructure:"rate_limit_rps"`
RateLimitBurst int `mapstructure:"rate_limit_burst"`
MaxRequestSize int64 `mapstructure:"max_request_size"`
}
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowedMethods []string `mapstructure:"allowed_methods"`
AllowedHeaders []string `mapstructure:"allowed_headers"`
MaxAge int `mapstructure:"max_age"`
}
// DatabaseConfig holds database configuration (primarily for testing)
type DatabaseConfig struct {
URL string `mapstructure:"url"`
}
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // sentry, noop
DSN string `mapstructure:"dsn"` // Sentry DSN
Environment string `mapstructure:"environment"` // e.g., production, staging, development
Release string `mapstructure:"release"` // Application version/release
Debug bool `mapstructure:"debug"` // Enable debug mode
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
}
// EventBrokerConfig contains configuration for the event broker
type EventBrokerConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // memory, redis, nats, database
Mode string `mapstructure:"mode"` // sync, async
WorkerCount int `mapstructure:"worker_count"`
BufferSize int `mapstructure:"buffer_size"`
InstanceID string `mapstructure:"instance_id"`
Redis EventBrokerRedisConfig `mapstructure:"redis"`
NATS EventBrokerNATSConfig `mapstructure:"nats"`
Database EventBrokerDatabaseConfig `mapstructure:"database"`
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
}
// EventBrokerRedisConfig contains Redis-specific configuration
type EventBrokerRedisConfig struct {
StreamName string `mapstructure:"stream_name"`
ConsumerGroup string `mapstructure:"consumer_group"`
MaxLen int64 `mapstructure:"max_len"`
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
// EventBrokerNATSConfig contains NATS-specific configuration
type EventBrokerNATSConfig struct {
URL string `mapstructure:"url"`
StreamName string `mapstructure:"stream_name"`
Subjects []string `mapstructure:"subjects"`
Storage string `mapstructure:"storage"` // file, memory
MaxAge time.Duration `mapstructure:"max_age"`
}
// EventBrokerDatabaseConfig contains database provider configuration
type EventBrokerDatabaseConfig struct {
TableName string `mapstructure:"table_name"`
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
PollInterval time.Duration `mapstructure:"poll_interval"`
}
// EventBrokerRetryPolicyConfig contains retry policy configuration
type EventBrokerRetryPolicyConfig struct {
MaxRetries int `mapstructure:"max_retries"`
InitialDelay time.Duration `mapstructure:"initial_delay"`
MaxDelay time.Duration `mapstructure:"max_delay"`
BackoffFactor float64 `mapstructure:"backoff_factor"`
}

203
pkg/config/manager.go Normal file
View File

@ -0,0 +1,203 @@
package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
// Manager handles configuration loading from multiple sources
type Manager struct {
v *viper.Viper
}
// NewManager creates a new configuration manager with defaults
func NewManager() *Manager {
v := viper.New()
// Set configuration file settings
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/resolvespec")
v.AddConfigPath("$HOME/.resolvespec")
// Enable environment variable support
v.SetEnvPrefix("RESOLVESPEC")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
// Set default values
setDefaults(v)
return &Manager{v: v}
}
// NewManagerWithOptions creates a new configuration manager with custom options
func NewManagerWithOptions(opts ...Option) *Manager {
m := NewManager()
for _, opt := range opts {
opt(m)
}
return m
}
// Option is a functional option for configuring the Manager
type Option func(*Manager)
// WithConfigFile sets a specific config file path
func WithConfigFile(path string) Option {
return func(m *Manager) {
m.v.SetConfigFile(path)
}
}
// WithConfigName sets the config file name (without extension)
func WithConfigName(name string) Option {
return func(m *Manager) {
m.v.SetConfigName(name)
}
}
// WithConfigPath adds a path to search for config files
func WithConfigPath(path string) Option {
return func(m *Manager) {
m.v.AddConfigPath(path)
}
}
// WithEnvPrefix sets the environment variable prefix
func WithEnvPrefix(prefix string) Option {
return func(m *Manager) {
m.v.SetEnvPrefix(prefix)
}
}
// Load attempts to load configuration from file and environment
func (m *Manager) Load() error {
// Try to read config file (not an error if it doesn't exist)
if err := m.v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return fmt.Errorf("error reading config file: %w", err)
}
// Config file not found; will rely on defaults and env vars
}
return nil
}
// GetConfig returns the complete configuration
func (m *Manager) GetConfig() (*Config, error) {
var cfg Config
if err := m.v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
return &cfg, nil
}
// Get returns a configuration value by key
func (m *Manager) Get(key string) interface{} {
return m.v.Get(key)
}
// GetString returns a string configuration value
func (m *Manager) GetString(key string) string {
return m.v.GetString(key)
}
// GetInt returns an int configuration value
func (m *Manager) GetInt(key string) int {
return m.v.GetInt(key)
}
// GetBool returns a bool configuration value
func (m *Manager) GetBool(key string) bool {
return m.v.GetBool(key)
}
// Set sets a configuration value
func (m *Manager) Set(key string, value interface{}) {
m.v.Set(key, value)
}
// setDefaults sets default configuration values
func setDefaults(v *viper.Viper) {
// Server defaults
v.SetDefault("server.addr", ":8080")
v.SetDefault("server.shutdown_timeout", "30s")
v.SetDefault("server.drain_timeout", "25s")
v.SetDefault("server.read_timeout", "10s")
v.SetDefault("server.write_timeout", "10s")
v.SetDefault("server.idle_timeout", "120s")
// Tracing defaults
v.SetDefault("tracing.enabled", false)
v.SetDefault("tracing.service_name", "resolvespec")
v.SetDefault("tracing.service_version", "1.0.0")
v.SetDefault("tracing.endpoint", "")
// Cache defaults
v.SetDefault("cache.provider", "memory")
v.SetDefault("cache.redis.host", "localhost")
v.SetDefault("cache.redis.port", 6379)
v.SetDefault("cache.redis.password", "")
v.SetDefault("cache.redis.db", 0)
v.SetDefault("cache.memcache.servers", []string{"localhost:11211"})
v.SetDefault("cache.memcache.max_idle_conns", 10)
v.SetDefault("cache.memcache.timeout", "100ms")
// Logger defaults
v.SetDefault("logger.dev", false)
v.SetDefault("logger.path", "")
// Middleware defaults
v.SetDefault("middleware.rate_limit_rps", 100.0)
v.SetDefault("middleware.rate_limit_burst", 200)
v.SetDefault("middleware.max_request_size", 10485760) // 10MB
// CORS defaults
v.SetDefault("cors.allowed_origins", []string{"*"})
v.SetDefault("cors.allowed_methods", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"})
v.SetDefault("cors.allowed_headers", []string{"*"})
v.SetDefault("cors.max_age", 3600)
// Database defaults
v.SetDefault("database.url", "")
// Event Broker defaults
v.SetDefault("event_broker.enabled", false)
v.SetDefault("event_broker.provider", "memory")
v.SetDefault("event_broker.mode", "async")
v.SetDefault("event_broker.worker_count", 10)
v.SetDefault("event_broker.buffer_size", 1000)
v.SetDefault("event_broker.instance_id", "")
// Event Broker - Redis defaults
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
v.SetDefault("event_broker.redis.max_len", 10000)
v.SetDefault("event_broker.redis.host", "localhost")
v.SetDefault("event_broker.redis.port", 6379)
v.SetDefault("event_broker.redis.password", "")
v.SetDefault("event_broker.redis.db", 0)
// Event Broker - NATS defaults
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
v.SetDefault("event_broker.nats.storage", "file")
v.SetDefault("event_broker.nats.max_age", "24h")
// Event Broker - Database defaults
v.SetDefault("event_broker.database.table_name", "events")
v.SetDefault("event_broker.database.channel", "resolvespec_events")
v.SetDefault("event_broker.database.poll_interval", "1s")
// Event Broker - Retry Policy defaults
v.SetDefault("event_broker.retry_policy.max_retries", 3)
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
}

166
pkg/config/manager_test.go Normal file
View File

@ -0,0 +1,166 @@
package config
import (
"os"
"testing"
"time"
)
func TestNewManager(t *testing.T) {
mgr := NewManager()
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
if mgr.v == nil {
t.Fatal("Expected viper instance to be non-nil")
}
}
func TestDefaultValues(t *testing.T) {
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test default values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":8080"},
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
{"tracing.enabled", cfg.Tracing.Enabled, false},
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
{"cache.provider", cfg.Cache.Provider, "memory"},
{"cache.redis.host", cfg.Cache.Redis.Host, "localhost"},
{"cache.redis.port", cfg.Cache.Redis.Port, 6379},
{"logger.dev", cfg.Logger.Dev, false},
{"middleware.rate_limit_rps", cfg.Middleware.RateLimitRPS, 100.0},
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestEnvironmentVariableOverrides(t *testing.T) {
// Set environment variables
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
defer func() {
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
}()
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test environment variable overrides
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":9090"},
{"tracing.enabled", cfg.Tracing.Enabled, true},
{"cache.provider", cfg.Cache.Provider, "redis"},
{"logger.dev", cfg.Logger.Dev, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestProgrammaticConfiguration(t *testing.T) {
mgr := NewManager()
mgr.Set("server.addr", ":7070")
mgr.Set("tracing.service_name", "test-service")
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":7070" {
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
}
if cfg.Tracing.ServiceName != "test-service" {
t.Errorf("tracing.service_name: got %s, want test-service", cfg.Tracing.ServiceName)
}
}
func TestGetterMethods(t *testing.T) {
mgr := NewManager()
mgr.Set("test.string", "value")
mgr.Set("test.int", 42)
mgr.Set("test.bool", true)
if got := mgr.GetString("test.string"); got != "value" {
t.Errorf("GetString: got %s, want value", got)
}
if got := mgr.GetInt("test.int"); got != 42 {
t.Errorf("GetInt: got %d, want 42", got)
}
if got := mgr.GetBool("test.bool"); !got {
t.Errorf("GetBool: got %v, want true", got)
}
}
func TestWithOptions(t *testing.T) {
mgr := NewManagerWithOptions(
WithEnvPrefix("MYAPP"),
WithConfigName("myconfig"),
)
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
// Set environment variable with custom prefix
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
defer os.Unsetenv("MYAPP_SERVER_ADDR")
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":5000" {
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
}
}

150
pkg/errortracking/README.md Normal file
View File

@ -0,0 +1,150 @@
# Error Tracking
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
## Features
- **Provider Interface**: Flexible design supporting multiple error tracking backends
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
- **Panic Tracking**: Automatic panic capture with stack traces
- **NoOp Provider**: Zero-overhead when error tracking is disabled
## Configuration
Add error tracking configuration to your config file:
```yaml
error_tracking:
enabled: true
provider: "sentry" # Currently supports: "sentry" or "noop"
dsn: "https://your-sentry-dsn@sentry.io/project-id"
environment: "production" # e.g., production, staging, development
release: "v1.0.0" # Your application version
debug: false
sample_rate: 1.0 # Error sample rate (0.0-1.0)
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
```
## Usage
### Initialization
Initialize error tracking in your application startup:
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func main() {
// Load your configuration
cfg := config.Config{
ErrorTracking: config.ErrorTrackingConfig{
Enabled: true,
Provider: "sentry",
DSN: "https://your-sentry-dsn@sentry.io/project-id",
Environment: "production",
Release: "v1.0.0",
SampleRate: 1.0,
},
}
// Initialize logger
logger.Init(false)
// Initialize error tracking
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
if err != nil {
logger.Error("Failed to initialize error tracking: %v", err)
} else {
logger.InitErrorTracking(provider)
}
// Your application code...
// Cleanup on shutdown
defer logger.CloseErrorTracking()
}
```
### Automatic Tracking
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
```go
// This will be logged AND sent to Sentry
logger.Error("Database connection failed: %v", err)
// This will also be logged AND sent to Sentry
logger.Warn("Cache miss for key: %s", key)
```
### Panic Tracking
Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
// Using HandlePanic
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("MyMethod", r)
}
}()
```
### Manual Tracking
You can also use the provider directly for custom error tracking:
```go
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func someFunction() {
tracker := logger.GetErrorTracker()
if tracker != nil {
// Capture an error
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
"user_id": userID,
"request_id": requestID,
})
// Capture a message
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
"event_type": "user_signup",
})
// Capture a panic
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
"context": "background_job",
})
}
}
```
## Severity Levels
The package supports the following severity levels:
- `SeverityError`: For errors that should be tracked and investigated
- `SeverityWarning`: For warnings that may indicate potential issues
- `SeverityInfo`: For informational messages
- `SeverityDebug`: For debug-level information
```

View File

@ -0,0 +1,67 @@
package errortracking
import (
"context"
"errors"
"testing"
)
func TestNoOpProvider(t *testing.T) {
provider := NewNoOpProvider()
// Test that all methods can be called without panicking
t.Run("CaptureError", func(t *testing.T) {
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
})
t.Run("CaptureMessage", func(t *testing.T) {
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
})
t.Run("CapturePanic", func(t *testing.T) {
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
})
t.Run("Flush", func(t *testing.T) {
result := provider.Flush(5)
if !result {
t.Error("Expected Flush to return true")
}
})
t.Run("Close", func(t *testing.T) {
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to return nil, got %v", err)
}
})
}
func TestSeverityLevels(t *testing.T) {
tests := []struct {
name string
severity Severity
expected string
}{
{"Error", SeverityError, "error"},
{"Warning", SeverityWarning, "warning"},
{"Info", SeverityInfo, "info"},
{"Debug", SeverityDebug, "debug"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.severity) != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
}
})
}
}
func TestProviderInterface(t *testing.T) {
// Test that NoOpProvider implements Provider interface
var _ Provider = (*NoOpProvider)(nil)
// Test that SentryProvider implements Provider interface
var _ Provider = (*SentryProvider)(nil)
}

View File

@ -0,0 +1,33 @@
package errortracking
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates an error tracking provider based on the configuration
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
if !cfg.Enabled {
return NewNoOpProvider(), nil
}
switch cfg.Provider {
case "sentry":
if cfg.DSN == "" {
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
}
return NewSentryProvider(SentryConfig{
DSN: cfg.DSN,
Environment: cfg.Environment,
Release: cfg.Release,
Debug: cfg.Debug,
SampleRate: cfg.SampleRate,
TracesSampleRate: cfg.TracesSampleRate,
})
case "noop", "":
return NewNoOpProvider(), nil
default:
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
}
}

View File

@ -0,0 +1,33 @@
package errortracking
import (
"context"
)
// Severity represents the severity level of an error
type Severity string
const (
SeverityError Severity = "error"
SeverityWarning Severity = "warning"
SeverityInfo Severity = "info"
SeverityDebug Severity = "debug"
)
// Provider defines the interface for error tracking providers
type Provider interface {
// CaptureError captures an error with the given severity and additional context
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
// CaptureMessage captures a message with the given severity and additional context
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
// CapturePanic captures a panic with stack trace
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
// Flush waits for all events to be sent (useful for graceful shutdown)
Flush(timeout int) bool
// Close closes the provider and releases resources
Close() error
}

37
pkg/errortracking/noop.go Normal file
View File

@ -0,0 +1,37 @@
package errortracking
import "context"
// NoOpProvider is a no-op implementation of the Provider interface
// Used when error tracking is disabled
type NoOpProvider struct{}
// NewNoOpProvider creates a new NoOp provider
func NewNoOpProvider() *NoOpProvider {
return &NoOpProvider{}
}
// CaptureError does nothing
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
// No-op
}
// CaptureMessage does nothing
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
// No-op
}
// CapturePanic does nothing
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
// No-op
}
// Flush does nothing and returns true
func (n *NoOpProvider) Flush(timeout int) bool {
return true
}
// Close does nothing
func (n *NoOpProvider) Close() error {
return nil
}

154
pkg/errortracking/sentry.go Normal file
View File

@ -0,0 +1,154 @@
package errortracking
import (
"context"
"fmt"
"time"
"github.com/getsentry/sentry-go"
)
// SentryProvider implements the Provider interface using Sentry
type SentryProvider struct {
hub *sentry.Hub
}
// SentryConfig holds the configuration for Sentry
type SentryConfig struct {
DSN string
Environment string
Release string
Debug bool
SampleRate float64
TracesSampleRate float64
}
// NewSentryProvider creates a new Sentry provider
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
err := sentry.Init(sentry.ClientOptions{
Dsn: config.DSN,
Environment: config.Environment,
Release: config.Release,
Debug: config.Debug,
AttachStacktrace: true,
SampleRate: config.SampleRate,
TracesSampleRate: config.TracesSampleRate,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
}
return &SentryProvider{
hub: sentry.CurrentHub(),
}, nil
}
// CaptureError captures an error with the given severity and additional context
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
if err == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = err.Error()
event.Exception = []sentry.Exception{
{
Value: err.Error(),
Type: fmt.Sprintf("%T", err),
Stacktrace: sentry.ExtractStacktrace(err),
},
}
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CaptureMessage captures a message with the given severity and additional context
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
if message == "" {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = message
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CapturePanic captures a panic with stack trace
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
if recovered == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = sentry.LevelError
event.Message = fmt.Sprintf("Panic: %v", recovered)
event.Exception = []sentry.Exception{
{
Value: fmt.Sprintf("%v", recovered),
Type: "panic",
},
}
if extra != nil {
event.Extra = extra
}
if stackTrace != nil {
event.Extra["stack_trace"] = string(stackTrace)
}
hub.CaptureEvent(event)
}
// Flush waits for all events to be sent (useful for graceful shutdown)
func (s *SentryProvider) Flush(timeout int) bool {
return sentry.Flush(time.Duration(timeout) * time.Second)
}
// Close closes the provider and releases resources
func (s *SentryProvider) Close() error {
sentry.Flush(2 * time.Second)
return nil
}
// convertSeverity converts our Severity to Sentry's Level
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
switch severity {
case SeverityError:
return sentry.LevelError
case SeverityWarning:
return sentry.LevelWarning
case SeverityInfo:
return sentry.LevelInfo
case SeverityDebug:
return sentry.LevelDebug
default:
return sentry.LevelError
}
}

327
pkg/eventbroker/README.md Normal file
View File

@ -0,0 +1,327 @@
# Event Broker System
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
## Features
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
- **Pattern Matching**: Subscribe to events using glob-style patterns
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
- **Retry Logic**: Configurable retry policy with exponential backoff
- **Metrics**: Prometheus-compatible metrics for monitoring
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
## Quick Start
### 1. Configuration
Add to your `config.yaml`:
```yaml
event_broker:
enabled: true
provider: memory # memory, redis, nats, database
mode: async # sync, async
worker_count: 10
buffer_size: 1000
instance_id: "${HOSTNAME}"
```
### 2. Initialize
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
)
func main() {
// Load configuration
cfgMgr := config.NewManager()
cfg, _ := cfgMgr.GetConfig()
// Initialize event broker
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
log.Fatal(err)
}
}
```
### 3. Subscribe to Events
```go
// Subscribe to specific events
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
log.Printf("New user created: %s", event.Payload)
// Send welcome email, update cache, etc.
return nil
},
))
// Subscribe with patterns
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
return nil
},
))
```
### 4. Publish Events
```go
// Create and publish an event
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
event.UserID = 123
event.SessionID = "session-456"
event.Schema = "public"
event.Entity = "users"
event.Operation = "update"
event.SetPayload(map[string]interface{}{
"id": 123,
"name": "John Doe",
})
// Async (non-blocking)
eventbroker.PublishAsync(ctx, event)
// Sync (blocking)
eventbroker.PublishSync(ctx, event)
```
## Automatic CRUD Event Capture
Automatically capture database CRUD operations:
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
func setupHooks(handler *restheadspec.Handler) {
broker := eventbroker.GetDefaultBroker()
// Configure which operations to capture
config := eventbroker.DefaultCRUDHookConfig()
config.EnableRead = false // Disable read events for performance
// Register hooks
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
// Now all create/update/delete operations automatically publish events!
}
```
## Event Structure
Every event contains:
```go
type Event struct {
ID string // UUID
Source EventSource // database, websocket, system, frontend, internal
Type string // Pattern: schema.entity.operation
Status EventStatus // pending, processing, completed, failed
Payload json.RawMessage // JSON payload
UserID int // User who triggered the event
SessionID string // Session identifier
InstanceID string // Server instance identifier
Schema string // Database schema
Entity string // Database entity/table
Operation string // create, update, delete, read
CreatedAt time.Time // When event was created
ProcessedAt *time.Time // When processing started
CompletedAt *time.Time // When processing completed
Error string // Error message if failed
Metadata map[string]interface{} // Additional context
RetryCount int // Number of retry attempts
}
```
## Pattern Matching
Subscribe to events using glob-style patterns:
| Pattern | Matches | Example |
|---------|---------|---------|
| `*` | All events | Any event |
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
| `public.users.create` | Exact match | Only `public.users.create` |
## Providers
### Memory Provider (Default)
Best for: Development, single-instance deployments
- **Pros**: Fast, no dependencies, simple
- **Cons**: Events lost on restart, single-instance only
```yaml
event_broker:
provider: memory
```
### Redis Provider (Future)
Best for: Production, multi-instance deployments
- **Pros**: Persistent, cross-instance pub/sub, reliable
- **Cons**: Requires Redis
```yaml
event_broker:
provider: redis
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
host: "localhost"
port: 6379
```
### NATS Provider (Future)
Best for: High-performance, low-latency requirements
- **Pros**: Very fast, built-in clustering, durable
- **Cons**: Requires NATS server
```yaml
event_broker:
provider: nats
nats:
url: "nats://localhost:4222"
stream_name: "RESOLVESPEC_EVENTS"
```
### Database Provider (Future)
Best for: Audit trails, event replay, SQL queries
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
- **Cons**: Slower than Redis/NATS
```yaml
event_broker:
provider: database
database:
table_name: "events"
channel: "resolvespec_events"
```
## Processing Modes
### Async Mode (Recommended)
Events are queued and processed by worker pool:
- Non-blocking event publishing
- Configurable worker count
- Better throughput
- Events may be processed out of order
```yaml
event_broker:
mode: async
worker_count: 10
buffer_size: 1000
```
### Sync Mode
Events are processed immediately:
- Blocking event publishing
- Guaranteed ordering
- Immediate error feedback
- Lower throughput
```yaml
event_broker:
mode: sync
```
## Retry Policy
Configure automatic retries for failed handlers:
```yaml
event_broker:
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 30s
backoff_factor: 2.0 # Exponential backoff
```
## Metrics
The event broker exposes Prometheus metrics:
- `eventbroker_events_published_total{source, type}` - Total events published
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
- `eventbroker_queue_size` - Current queue size (async mode)
## Best Practices
1. **Use Async Mode**: For better performance, use async mode in production
2. **Disable Read Events**: Read events can be high volume; disable if not needed
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
5. **Idempotency**: Make handlers idempotent as events may be retried
6. **Payload Size**: Keep payloads reasonable; avoid large objects
7. **Monitoring**: Monitor metrics to detect issues early
## Examples
See `example_usage.go` for comprehensive examples including:
- Basic event publishing and subscription
- Hook integration
- Error handling
- Configuration
- Pattern matching
## Architecture
```
┌─────────────────┐
│ Application │
└────────┬────────┘
├─ Publish Events
┌────────▼────────┐ ┌──────────────┐
│ Event Broker │◄────►│ Subscribers │
└────────┬────────┘ └──────────────┘
├─ Store Events
┌────────▼────────┐
│ Provider │
│ (Memory/Redis │
│ /NATS/DB) │
└─────────────────┘
```
## Future Enhancements
- [ ] Database Provider with PostgreSQL NOTIFY
- [ ] Redis Streams Provider
- [ ] NATS JetStream Provider
- [ ] Event replay functionality
- [ ] Dead letter queue
- [ ] Event filtering at provider level
- [ ] Batch publishing
- [ ] Event compression
- [ ] Schema versioning

453
pkg/eventbroker/broker.go Normal file
View File

@ -0,0 +1,453 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Broker is the main interface for event publishing and subscription
type Broker interface {
// Publish publishes an event (mode-dependent: sync or async)
Publish(ctx context.Context, event *Event) error
// PublishSync publishes an event synchronously (blocks until all handlers complete)
PublishSync(ctx context.Context, event *Event) error
// PublishAsync publishes an event asynchronously (returns immediately)
PublishAsync(ctx context.Context, event *Event) error
// Subscribe registers a handler for events matching the pattern
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
// Unsubscribe removes a subscription
Unsubscribe(id SubscriptionID) error
// Start starts the broker (begins processing events)
Start(ctx context.Context) error
// Stop stops the broker gracefully (flushes pending events)
Stop(ctx context.Context) error
// Stats returns broker statistics
Stats(ctx context.Context) (*BrokerStats, error)
// InstanceID returns the instance ID of this broker
InstanceID() string
}
// ProcessingMode determines how events are processed
type ProcessingMode string
const (
ProcessingModeSync ProcessingMode = "sync"
ProcessingModeAsync ProcessingMode = "async"
)
// BrokerStats contains broker statistics
type BrokerStats struct {
InstanceID string `json:"instance_id"`
Mode ProcessingMode `json:"mode"`
IsRunning bool `json:"is_running"`
TotalPublished int64 `json:"total_published"`
TotalProcessed int64 `json:"total_processed"`
TotalFailed int64 `json:"total_failed"`
ActiveSubscribers int `json:"active_subscribers"`
QueueSize int `json:"queue_size,omitempty"` // For async mode
ActiveWorkers int `json:"active_workers,omitempty"` // For async mode
ProviderStats *ProviderStats `json:"provider_stats,omitempty"`
AdditionalStats map[string]interface{} `json:"additional_stats,omitempty"`
}
// EventBroker implements the Broker interface
type EventBroker struct {
provider Provider
subscriptions *subscriptionManager
mode ProcessingMode
instanceID string
retryPolicy *RetryPolicy
// Async mode fields (initialized in Phase 4)
workerPool *workerPool
// Runtime state
isRunning atomic.Bool
stopOnce sync.Once
stopCh chan struct{}
wg sync.WaitGroup
// Statistics
statsPublished atomic.Int64
statsProcessed atomic.Int64
statsFailed atomic.Int64
}
// RetryPolicy defines how failed events should be retried
type RetryPolicy struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
BackoffFactor float64
}
// DefaultRetryPolicy returns a sensible default retry policy
func DefaultRetryPolicy() *RetryPolicy {
return &RetryPolicy{
MaxRetries: 3,
InitialDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
}
}
// Options for creating a new broker
type Options struct {
Provider Provider
Mode ProcessingMode
WorkerCount int // For async mode
BufferSize int // For async mode
RetryPolicy *RetryPolicy
InstanceID string
}
// NewBroker creates a new event broker with the given options
func NewBroker(opts Options) (*EventBroker, error) {
if opts.Provider == nil {
return nil, fmt.Errorf("provider is required")
}
if opts.InstanceID == "" {
return nil, fmt.Errorf("instance ID is required")
}
if opts.Mode == "" {
opts.Mode = ProcessingModeAsync // Default to async
}
if opts.RetryPolicy == nil {
opts.RetryPolicy = DefaultRetryPolicy()
}
broker := &EventBroker{
provider: opts.Provider,
subscriptions: newSubscriptionManager(),
mode: opts.Mode,
instanceID: opts.InstanceID,
retryPolicy: opts.RetryPolicy,
stopCh: make(chan struct{}),
}
// Worker pool will be initialized in Phase 4 for async mode
if opts.Mode == ProcessingModeAsync {
if opts.WorkerCount == 0 {
opts.WorkerCount = 10 // Default
}
if opts.BufferSize == 0 {
opts.BufferSize = 1000 // Default
}
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
}
return broker, nil
}
// Functional option pattern helpers
func WithProvider(p Provider) func(*Options) {
return func(o *Options) { o.Provider = p }
}
func WithMode(m ProcessingMode) func(*Options) {
return func(o *Options) { o.Mode = m }
}
func WithWorkerCount(count int) func(*Options) {
return func(o *Options) { o.WorkerCount = count }
}
func WithBufferSize(size int) func(*Options) {
return func(o *Options) { o.BufferSize = size }
}
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
return func(o *Options) { o.RetryPolicy = policy }
}
func WithInstanceID(id string) func(*Options) {
return func(o *Options) { o.InstanceID = id }
}
// Start starts the broker
func (b *EventBroker) Start(ctx context.Context) error {
if b.isRunning.Load() {
return fmt.Errorf("broker already running")
}
b.isRunning.Store(true)
// Start worker pool for async mode
if b.mode == ProcessingModeAsync && b.workerPool != nil {
b.workerPool.Start()
}
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
return nil
}
// Stop stops the broker gracefully
func (b *EventBroker) Stop(ctx context.Context) error {
var stopErr error
b.stopOnce.Do(func() {
logger.Info("Stopping event broker...")
// Mark as not running
b.isRunning.Store(false)
// Close the stop channel
close(b.stopCh)
// Stop worker pool for async mode
if b.mode == ProcessingModeAsync && b.workerPool != nil {
if err := b.workerPool.Stop(ctx); err != nil {
logger.Error("Error stopping worker pool: %v", err)
stopErr = err
}
}
// Wait for all goroutines
b.wg.Wait()
// Close provider
if err := b.provider.Close(); err != nil {
logger.Error("Error closing provider: %v", err)
if stopErr == nil {
stopErr = err
}
}
logger.Info("Event broker stopped")
})
return stopErr
}
// Publish publishes an event based on the broker's mode
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
if b.mode == ProcessingModeSync {
return b.PublishSync(ctx, event)
}
return b.PublishAsync(ctx, event)
}
// PublishSync publishes an event synchronously
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
if !b.isRunning.Load() {
return fmt.Errorf("broker is not running")
}
// Validate event
if err := event.Validate(); err != nil {
return fmt.Errorf("invalid event: %w", err)
}
// Store event in provider
if err := b.provider.Publish(ctx, event); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}
b.statsPublished.Add(1)
// Record metrics
recordEventPublished(event)
// Process event synchronously
if err := b.processEvent(ctx, event); err != nil {
logger.Error("Failed to process event %s: %v", event.ID, err)
b.statsFailed.Add(1)
return err
}
b.statsProcessed.Add(1)
return nil
}
// PublishAsync publishes an event asynchronously
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
if !b.isRunning.Load() {
return fmt.Errorf("broker is not running")
}
// Validate event
if err := event.Validate(); err != nil {
return fmt.Errorf("invalid event: %w", err)
}
// Store event in provider
if err := b.provider.Publish(ctx, event); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}
b.statsPublished.Add(1)
// Record metrics
recordEventPublished(event)
// Queue for async processing
if b.mode == ProcessingModeAsync && b.workerPool != nil {
// Update queue size metrics
updateQueueSize(int64(b.workerPool.QueueSize()))
return b.workerPool.Submit(ctx, event)
}
// Fallback to sync if async not configured
return b.processEvent(ctx, event)
}
// Subscribe adds a subscription for events matching the pattern
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
return b.subscriptions.Subscribe(pattern, handler)
}
// Unsubscribe removes a subscription
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
return b.subscriptions.Unsubscribe(id)
}
// processEvent processes an event by calling all matching handlers
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
startTime := time.Now()
// Get all handlers matching this event type
handlers := b.subscriptions.GetMatching(event.Type)
if len(handlers) == 0 {
logger.Debug("No handlers for event type: %s", event.Type)
return nil
}
logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))
// Mark event as processing
event.MarkProcessing()
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Execute all handlers
var lastErr error
for i, handler := range handlers {
if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
lastErr = err
// Continue processing other handlers
}
}
// Update final status
if lastErr != nil {
event.MarkFailed(lastErr)
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Record metrics
recordEventProcessed(event, time.Since(startTime))
return lastErr
}
event.MarkCompleted()
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Record metrics
recordEventProcessed(event, time.Since(startTime))
return nil
}
// executeHandlerWithRetry executes a handler with retry logic
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
var lastErr error
for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
if attempt > 0 {
// Calculate backoff delay
delay := b.calculateBackoff(attempt)
logger.Debug("Retrying event %s (attempt %d/%d) after %v",
event.ID, attempt, b.retryPolicy.MaxRetries, delay)
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
event.IncrementRetry()
}
// Execute handler
if err := handler.Handle(ctx, event); err != nil {
lastErr = err
logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
continue
}
// Success
return nil
}
return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
}
// calculateBackoff calculates the backoff delay for a retry attempt
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
if delay > float64(b.retryPolicy.MaxDelay) {
delay = float64(b.retryPolicy.MaxDelay)
}
return time.Duration(delay)
}
// pow is a simple integer power function
func pow(base float64, exp float64) float64 {
result := 1.0
for i := 0.0; i < exp; i++ {
result *= base
}
return result
}
// Stats returns broker statistics
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
providerStats, err := b.provider.Stats(ctx)
if err != nil {
logger.Warn("Failed to get provider stats: %v", err)
}
stats := &BrokerStats{
InstanceID: b.instanceID,
Mode: b.mode,
IsRunning: b.isRunning.Load(),
TotalPublished: b.statsPublished.Load(),
TotalProcessed: b.statsProcessed.Load(),
TotalFailed: b.statsFailed.Load(),
ActiveSubscribers: b.subscriptions.Count(),
ProviderStats: providerStats,
}
// Add async-specific stats
if b.mode == ProcessingModeAsync && b.workerPool != nil {
stats.QueueSize = b.workerPool.QueueSize()
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
}
return stats, nil
}
// InstanceID returns the instance ID
func (b *EventBroker) InstanceID() string {
return b.instanceID
}

View File

@ -0,0 +1,524 @@
package eventbroker
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestNewBroker(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 1000,
})
tests := []struct {
name string
opts Options
wantError bool
}{
{
name: "valid options",
opts: Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
},
wantError: false,
},
{
name: "missing provider",
opts: Options{
InstanceID: "test-instance",
},
wantError: true,
},
{
name: "missing instance ID",
opts: Options{
Provider: provider,
},
wantError: true,
},
{
name: "async mode with defaults",
opts: Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
broker, err := NewBroker(tt.opts)
if (err != nil) != tt.wantError {
t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
}
if err == nil && broker == nil {
t.Error("Expected non-nil broker")
}
})
}
}
func TestBrokerStartStop(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, err := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
if err != nil {
t.Fatalf("Failed to create broker: %v", err)
}
// Test Start
if err := broker.Start(context.Background()); err != nil {
t.Fatalf("Failed to start broker: %v", err)
}
// Test double start (should fail)
if err := broker.Start(context.Background()); err == nil {
t.Error("Expected error on double start")
}
// Test Stop
if err := broker.Stop(context.Background()); err != nil {
t.Fatalf("Failed to stop broker: %v", err)
}
// Test double stop (should not fail)
if err := broker.Stop(context.Background()); err != nil {
t.Error("Double stop should not fail")
}
}
func TestBrokerPublishSync(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe to events
called := false
var receivedEvent *Event
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
receivedEvent = event
return nil
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.PublishSync(context.Background(), event)
if err != nil {
t.Fatalf("PublishSync failed: %v", err)
}
// Verify handler was called
if !called {
t.Error("Expected handler to be called")
}
if receivedEvent == nil || receivedEvent.ID != event.ID {
t.Error("Expected to receive the published event")
}
// Verify event status
if event.Status != EventStatusCompleted {
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
}
}
func TestBrokerPublishAsync(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 2,
BufferSize: 10,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe to events
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return nil
}))
// Publish multiple events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
if err := broker.PublishAsync(context.Background(), event); err != nil {
t.Fatalf("PublishAsync failed: %v", err)
}
}
// Wait for events to be processed
time.Sleep(100 * time.Millisecond)
if callCount.Load() != 5 {
t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
}
}
func TestBrokerPublishBeforeStart(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
})
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.Publish(context.Background(), event)
if err == nil {
t.Error("Expected error when publishing before start")
}
}
func TestBrokerHandlerError(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
RetryPolicy: &RetryPolicy{
MaxRetries: 2,
InitialDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
},
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe with failing handler
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return errors.New("handler error")
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.PublishSync(context.Background(), event)
// Should fail after retries
if err == nil {
t.Error("Expected error from handler")
}
// Should have been called MaxRetries+1 times (initial + retries)
if callCount.Load() != 3 {
t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
}
// Event should be marked as failed
if event.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
}
}
func TestBrokerMultipleHandlers(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe multiple handlers
var called1, called2, called3 bool
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called1 = true
return nil
}))
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called2 = true
return nil
}))
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called3 = true
return nil
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
// All handlers should be called
if !called1 || !called2 || !called3 {
t.Error("Expected all handlers to be called")
}
}
func TestBrokerUnsubscribe(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe
called := false
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
}))
// Unsubscribe
if err := broker.Unsubscribe(id); err != nil {
t.Fatalf("Unsubscribe failed: %v", err)
}
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
// Handler should not be called
if called {
t.Error("Expected handler not to be called after unsubscribe")
}
}
func TestBrokerStats(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
}))
// Publish events
for i := 0; i < 3; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
}
// Get stats
stats, err := broker.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.InstanceID != "test-instance" {
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
}
if stats.TotalPublished != 3 {
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
}
if stats.TotalProcessed != 3 {
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
}
if stats.ActiveSubscribers != 1 {
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
}
}
func TestBrokerInstanceID(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "my-instance",
})
if broker.InstanceID() != "my-instance" {
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
}
}
func TestBrokerConcurrentPublish(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 5,
BufferSize: 100,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return nil
}))
// Publish concurrently
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishAsync(context.Background(), event)
}()
}
wg.Wait()
time.Sleep(200 * time.Millisecond) // Wait for async processing
if callCount.Load() != 50 {
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
}
}
func TestBrokerGracefulShutdown(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 2,
BufferSize: 10,
})
broker.Start(context.Background())
var processedCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
time.Sleep(50 * time.Millisecond) // Simulate work
processedCount.Add(1)
return nil
}))
// Publish events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishAsync(context.Background(), event)
}
// Stop broker (should wait for events to be processed)
if err := broker.Stop(context.Background()); err != nil {
t.Fatalf("Stop failed: %v", err)
}
// All events should be processed
if processedCount.Load() != 5 {
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
}
}
func TestBrokerDefaultRetryPolicy(t *testing.T) {
policy := DefaultRetryPolicy()
if policy.MaxRetries != 3 {
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
}
if policy.InitialDelay != 1*time.Second {
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
}
if policy.BackoffFactor != 2.0 {
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
}
}
func TestBrokerProcessingModes(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
tests := []struct {
name string
mode ProcessingMode
}{
{"sync mode", ProcessingModeSync},
{"async mode", ProcessingModeAsync},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: tt.mode,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
called := false
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
}))
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.Publish(context.Background(), event)
if tt.mode == ProcessingModeAsync {
time.Sleep(50 * time.Millisecond)
}
if !called {
t.Error("Expected handler to be called")
}
})
}
}

175
pkg/eventbroker/event.go Normal file
View File

@ -0,0 +1,175 @@
package eventbroker
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
)
// EventSource represents where an event originated from
type EventSource string
const (
EventSourceDatabase EventSource = "database"
EventSourceWebSocket EventSource = "websocket"
EventSourceFrontend EventSource = "frontend"
EventSourceSystem EventSource = "system"
EventSourceInternal EventSource = "internal"
)
// EventStatus represents the current state of an event
type EventStatus string
const (
EventStatusPending EventStatus = "pending"
EventStatusProcessing EventStatus = "processing"
EventStatusCompleted EventStatus = "completed"
EventStatusFailed EventStatus = "failed"
)
// Event represents a single event in the system with complete metadata
type Event struct {
// Identification
ID string `json:"id" db:"id"`
// Source & Classification
Source EventSource `json:"source" db:"source"`
Type string `json:"type" db:"type"` // Pattern: schema.entity.operation
// Status Tracking
Status EventStatus `json:"status" db:"status"`
RetryCount int `json:"retry_count" db:"retry_count"`
Error string `json:"error,omitempty" db:"error"`
// Payload
Payload json.RawMessage `json:"payload" db:"payload"`
// Context Information
UserID int `json:"user_id" db:"user_id"`
SessionID string `json:"session_id" db:"session_id"`
InstanceID string `json:"instance_id" db:"instance_id"`
// Database Context
Schema string `json:"schema" db:"schema"`
Entity string `json:"entity" db:"entity"`
Operation string `json:"operation" db:"operation"` // create, update, delete, read
// Timestamps
CreatedAt time.Time `json:"created_at" db:"created_at"`
ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
// Extensibility
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
}
// NewEvent creates a new event with defaults
func NewEvent(source EventSource, eventType string) *Event {
return &Event{
ID: uuid.New().String(),
Source: source,
Type: eventType,
Status: EventStatusPending,
CreatedAt: time.Now(),
Metadata: make(map[string]interface{}),
RetryCount: 0,
}
}
// EventType generates a type string from schema, entity, and operation
// Pattern: schema.entity.operation (e.g., "public.users.create")
func EventType(schema, entity, operation string) string {
return fmt.Sprintf("%s.%s.%s", schema, entity, operation)
}
// MarkProcessing marks the event as being processed
func (e *Event) MarkProcessing() {
e.Status = EventStatusProcessing
now := time.Now()
e.ProcessedAt = &now
}
// MarkCompleted marks the event as successfully completed
func (e *Event) MarkCompleted() {
e.Status = EventStatusCompleted
now := time.Now()
e.CompletedAt = &now
}
// MarkFailed marks the event as failed with an error message
func (e *Event) MarkFailed(err error) {
e.Status = EventStatusFailed
e.Error = err.Error()
now := time.Now()
e.CompletedAt = &now
}
// IncrementRetry increments the retry counter
func (e *Event) IncrementRetry() {
e.RetryCount++
}
// SetPayload sets the event payload from any value by marshaling to JSON
func (e *Event) SetPayload(v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
e.Payload = data
return nil
}
// GetPayload unmarshals the payload into the provided value
func (e *Event) GetPayload(v interface{}) error {
if len(e.Payload) == 0 {
return fmt.Errorf("payload is empty")
}
if err := json.Unmarshal(e.Payload, v); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
return nil
}
// Clone creates a deep copy of the event
func (e *Event) Clone() *Event {
clone := *e
// Deep copy metadata
if e.Metadata != nil {
clone.Metadata = make(map[string]interface{})
for k, v := range e.Metadata {
clone.Metadata[k] = v
}
}
// Deep copy timestamps
if e.ProcessedAt != nil {
t := *e.ProcessedAt
clone.ProcessedAt = &t
}
if e.CompletedAt != nil {
t := *e.CompletedAt
clone.CompletedAt = &t
}
return &clone
}
// Validate performs basic validation on the event
func (e *Event) Validate() error {
if e.ID == "" {
return fmt.Errorf("event ID is required")
}
if e.Source == "" {
return fmt.Errorf("event source is required")
}
if e.Type == "" {
return fmt.Errorf("event type is required")
}
if e.InstanceID == "" {
return fmt.Errorf("instance ID is required")
}
return nil
}

View File

@ -0,0 +1,314 @@
package eventbroker
import (
"encoding/json"
"errors"
"testing"
"time"
)
func TestNewEvent(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
if event.ID == "" {
t.Error("Expected event ID to be generated")
}
if event.Source != EventSourceDatabase {
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
}
if event.Type != "public.users.create" {
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
}
if event.Status != EventStatusPending {
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
}
if event.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if event.Metadata == nil {
t.Error("Expected Metadata to be initialized")
}
}
func TestEventType(t *testing.T) {
tests := []struct {
schema string
entity string
operation string
expected string
}{
{"public", "users", "create", "public.users.create"},
{"admin", "roles", "update", "admin.roles.update"},
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
}
for _, tt := range tests {
result := EventType(tt.schema, tt.entity, tt.operation)
if result != tt.expected {
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
tt.schema, tt.entity, tt.operation, result, tt.expected)
}
}
}
func TestEventValidate(t *testing.T) {
tests := []struct {
name string
event *Event
wantError bool
}{
{
name: "valid event",
event: func() *Event {
e := NewEvent(EventSourceDatabase, "public.users.create")
e.InstanceID = "test-instance"
return e
}(),
wantError: false,
},
{
name: "missing ID",
event: &Event{
Source: EventSourceDatabase,
Type: "public.users.create",
Status: EventStatusPending,
},
wantError: true,
},
{
name: "missing source",
event: &Event{
ID: "test-id",
Type: "public.users.create",
Status: EventStatusPending,
},
wantError: true,
},
{
name: "missing type",
event: &Event{
ID: "test-id",
Source: EventSourceDatabase,
Status: EventStatusPending,
},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.event.Validate()
if (err != nil) != tt.wantError {
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
}
})
}
}
func TestEventSetPayload(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
payload := map[string]interface{}{
"id": 1,
"name": "John Doe",
"email": "john@example.com",
}
err := event.SetPayload(payload)
if err != nil {
t.Fatalf("SetPayload failed: %v", err)
}
if event.Payload == nil {
t.Fatal("Expected payload to be set")
}
// Verify payload can be unmarshaled
var result map[string]interface{}
if err := json.Unmarshal(event.Payload, &result); err != nil {
t.Fatalf("Failed to unmarshal payload: %v", err)
}
if result["name"] != "John Doe" {
t.Errorf("Expected name 'John Doe', got %v", result["name"])
}
}
func TestEventGetPayload(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
payload := map[string]interface{}{
"id": float64(1), // JSON unmarshals numbers as float64
"name": "John Doe",
}
if err := event.SetPayload(payload); err != nil {
t.Fatalf("SetPayload failed: %v", err)
}
var result map[string]interface{}
if err := event.GetPayload(&result); err != nil {
t.Fatalf("GetPayload failed: %v", err)
}
if result["name"] != "John Doe" {
t.Errorf("Expected name 'John Doe', got %v", result["name"])
}
}
func TestEventMarkProcessing(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.MarkProcessing()
if event.Status != EventStatusProcessing {
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
}
if event.ProcessedAt == nil {
t.Error("Expected ProcessedAt to be set")
}
}
func TestEventMarkCompleted(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.MarkCompleted()
if event.Status != EventStatusCompleted {
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
}
if event.CompletedAt == nil {
t.Error("Expected CompletedAt to be set")
}
}
func TestEventMarkFailed(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
testErr := errors.New("test error")
event.MarkFailed(testErr)
if event.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
}
if event.Error != "test error" {
t.Errorf("Expected error %q, got %q", "test error", event.Error)
}
if event.CompletedAt == nil {
t.Error("Expected CompletedAt to be set")
}
}
func TestEventIncrementRetry(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
initialCount := event.RetryCount
event.IncrementRetry()
if event.RetryCount != initialCount+1 {
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
}
}
func TestEventJSONMarshaling(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.UserID = 123
event.SessionID = "session-123"
event.InstanceID = "instance-1"
event.Schema = "public"
event.Entity = "users"
event.Operation = "create"
event.SetPayload(map[string]interface{}{"name": "Test"})
// Marshal to JSON
data, err := json.Marshal(event)
if err != nil {
t.Fatalf("Failed to marshal event: %v", err)
}
// Unmarshal back
var decoded Event
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal event: %v", err)
}
// Verify fields
if decoded.ID != event.ID {
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
}
if decoded.Source != event.Source {
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
}
if decoded.UserID != event.UserID {
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
}
}
func TestEventStatusString(t *testing.T) {
statuses := []EventStatus{
EventStatusPending,
EventStatusProcessing,
EventStatusCompleted,
EventStatusFailed,
}
for _, status := range statuses {
if string(status) == "" {
t.Errorf("EventStatus %v has empty string representation", status)
}
}
}
func TestEventSourceString(t *testing.T) {
sources := []EventSource{
EventSourceDatabase,
EventSourceWebSocket,
EventSourceFrontend,
EventSourceSystem,
EventSourceInternal,
}
for _, source := range sources {
if string(source) == "" {
t.Errorf("EventSource %v has empty string representation", source)
}
}
}
func TestEventMetadata(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
// Test setting metadata
event.Metadata["key1"] = "value1"
event.Metadata["key2"] = 123
if event.Metadata["key1"] != "value1" {
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
}
if event.Metadata["key2"] != 123 {
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
}
}
func TestEventTimestamps(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
createdAt := event.CreatedAt
// Wait a tiny bit to ensure timestamps differ
time.Sleep(time.Millisecond)
event.MarkProcessing()
if event.ProcessedAt == nil {
t.Fatal("ProcessedAt should be set")
}
if !event.ProcessedAt.After(createdAt) {
t.Error("ProcessedAt should be after CreatedAt")
}
time.Sleep(time.Millisecond)
event.MarkCompleted()
if event.CompletedAt == nil {
t.Fatal("CompletedAt should be set")
}
if !event.CompletedAt.After(*event.ProcessedAt) {
t.Error("CompletedAt should be after ProcessedAt")
}
}

View File

@ -0,0 +1,160 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
var (
defaultBroker Broker
brokerMu sync.RWMutex
)
// Initialize initializes the global event broker from configuration
func Initialize(cfg config.EventBrokerConfig) error {
if !cfg.Enabled {
logger.Info("Event broker is disabled")
return nil
}
// Create provider
provider, err := NewProviderFromConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create provider: %w", err)
}
// Parse mode
mode := ProcessingModeAsync
if cfg.Mode == "sync" {
mode = ProcessingModeSync
}
// Convert retry policy
retryPolicy := &RetryPolicy{
MaxRetries: cfg.RetryPolicy.MaxRetries,
InitialDelay: cfg.RetryPolicy.InitialDelay,
MaxDelay: cfg.RetryPolicy.MaxDelay,
BackoffFactor: cfg.RetryPolicy.BackoffFactor,
}
if retryPolicy.MaxRetries == 0 {
retryPolicy = DefaultRetryPolicy()
}
// Create broker options
opts := Options{
Provider: provider,
Mode: mode,
WorkerCount: cfg.WorkerCount,
BufferSize: cfg.BufferSize,
RetryPolicy: retryPolicy,
InstanceID: getInstanceID(cfg.InstanceID),
}
// Create broker
broker, err := NewBroker(opts)
if err != nil {
return fmt.Errorf("failed to create broker: %w", err)
}
// Start broker
if err := broker.Start(context.Background()); err != nil {
return fmt.Errorf("failed to start broker: %w", err)
}
// Set as default
SetDefaultBroker(broker)
// Register shutdown callback
RegisterShutdown(broker)
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
cfg.Provider, cfg.Mode, opts.InstanceID)
return nil
}
// SetDefaultBroker sets the default global broker
func SetDefaultBroker(broker Broker) {
brokerMu.Lock()
defer brokerMu.Unlock()
defaultBroker = broker
}
// GetDefaultBroker returns the default global broker
func GetDefaultBroker() Broker {
brokerMu.RLock()
defer brokerMu.RUnlock()
return defaultBroker
}
// IsInitialized returns true if the default broker is initialized
func IsInitialized() bool {
return GetDefaultBroker() != nil
}
// Publish publishes an event using the default broker
func Publish(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.Publish(ctx, event)
}
// PublishSync publishes an event synchronously using the default broker
func PublishSync(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.PublishSync(ctx, event)
}
// PublishAsync publishes an event asynchronously using the default broker
func PublishAsync(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.PublishAsync(ctx, event)
}
// Subscribe subscribes to events using the default broker
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
broker := GetDefaultBroker()
if broker == nil {
return "", fmt.Errorf("event broker not initialized")
}
return broker.Subscribe(pattern, handler)
}
// Unsubscribe unsubscribes from events using the default broker
func Unsubscribe(id SubscriptionID) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.Unsubscribe(id)
}
// Stats returns statistics from the default broker
func Stats(ctx context.Context) (*BrokerStats, error) {
broker := GetDefaultBroker()
if broker == nil {
return nil, fmt.Errorf("event broker not initialized")
}
return broker.Stats(ctx)
}
// RegisterShutdown registers the broker's shutdown with the server shutdown callbacks
func RegisterShutdown(broker Broker) {
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Shutting down event broker...")
return broker.Stop(ctx)
})
}

View File

@ -0,0 +1,266 @@
// nolint
package eventbroker
import (
"context"
"fmt"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example demonstrates basic usage of the event broker
func Example() {
// 1. Create a memory provider
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "example-instance",
MaxEvents: 1000,
CleanupInterval: 5 * time.Minute,
MaxAge: 1 * time.Hour,
})
// 2. Create a broker
broker, err := NewBroker(Options{
Provider: provider,
Mode: ProcessingModeAsync,
WorkerCount: 5,
BufferSize: 100,
RetryPolicy: DefaultRetryPolicy(),
InstanceID: "example-instance",
})
if err != nil {
logger.Error("Failed to create broker: %v", err)
return
}
// 3. Start the broker
if err := broker.Start(context.Background()); err != nil {
logger.Error("Failed to start broker: %v", err)
return
}
defer func() {
err := broker.Stop(context.Background())
if err != nil {
logger.Error("Failed to stop broker: %v", err)
}
}()
// 4. Subscribe to events
broker.Subscribe("public.users.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
return nil
},
))
broker.Subscribe("*.*.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
return nil
},
))
// 5. Publish events
ctx := context.Background()
// Database event
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
dbEvent.InstanceID = "example-instance"
dbEvent.UserID = 123
dbEvent.SessionID = "session-456"
dbEvent.Schema = "public"
dbEvent.Entity = "users"
dbEvent.Operation = "create"
dbEvent.SetPayload(map[string]interface{}{
"id": 123,
"name": "John Doe",
"email": "john@example.com",
})
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
logger.Error("Failed to publish event: %v", err)
}
// WebSocket event
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
wsEvent.InstanceID = "example-instance"
wsEvent.UserID = 123
wsEvent.SessionID = "session-456"
wsEvent.SetPayload(map[string]interface{}{
"room": "general",
"message": "Hello, World!",
})
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
logger.Error("Failed to publish event: %v", err)
}
// 6. Get statistics
time.Sleep(1 * time.Second) // Wait for processing
stats, _ := broker.Stats(ctx)
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
}
// ExampleWithHooks demonstrates integration with the hook system
func ExampleWithHooks() {
// This would typically be called in your main.go or initialization code
// after setting up your restheadspec.Handler
// Pseudo-code (actual implementation would use real handler):
/*
broker := eventbroker.GetDefaultBroker()
hookRegistry := handler.Hooks()
// Register CRUD hooks
config := eventbroker.DefaultCRUDHookConfig()
config.EnableRead = false // Disable read events for performance
if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
logger.Error("Failed to register CRUD hooks: %v", err)
}
// Now all CRUD operations will automatically publish events
*/
}
// ExampleSubscriptionPatterns demonstrates different subscription patterns
func ExampleSubscriptionPatterns() {
broker := GetDefaultBroker()
if broker == nil {
return
}
// Pattern 1: Subscribe to all events from a specific entity
broker.Subscribe("public.users.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("User event: %s\n", event.Operation)
return nil
},
))
// Pattern 2: Subscribe to a specific operation across all entities
broker.Subscribe("*.*.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
return nil
},
))
// Pattern 3: Subscribe to all events in a schema
broker.Subscribe("public.*.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
return nil
},
))
// Pattern 4: Subscribe to everything (use with caution)
broker.Subscribe("*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Any event: %s\n", event.Type)
return nil
},
))
}
// ExampleErrorHandling demonstrates error handling in event handlers
func ExampleErrorHandling() {
broker := GetDefaultBroker()
if broker == nil {
return
}
// Handler that may fail
broker.Subscribe("public.users.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
// Simulate processing
var user struct {
ID int `json:"id"`
Email string `json:"email"`
}
if err := event.GetPayload(&user); err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
// Validate
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Process (e.g., send email)
logger.Info("Sending welcome email to %s", user.Email)
return nil
},
))
}
// ExampleConfiguration demonstrates initializing from configuration
func ExampleConfiguration() {
// This would typically be in your main.go
// Pseudo-code:
/*
// Load configuration
cfgMgr := config.NewManager()
if err := cfgMgr.Load(); err != nil {
logger.Fatal("Failed to load config: %v", err)
}
cfg, err := cfgMgr.GetConfig()
if err != nil {
logger.Fatal("Failed to get config: %v", err)
}
// Initialize event broker
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
logger.Fatal("Failed to initialize event broker: %v", err)
}
// Use the default broker
eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
logger.Info("Created: %s.%s", event.Schema, event.Entity)
return nil
},
))
*/
}
// ExampleYAMLConfiguration shows example YAML configuration
const ExampleYAMLConfiguration = `
event_broker:
enabled: true
provider: memory # memory, redis, nats, database
mode: async # sync, async
worker_count: 10
buffer_size: 1000
instance_id: "${HOSTNAME}"
# Memory provider is default, no additional config needed
# Redis provider (when provider: redis)
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
host: "localhost"
port: 6379
# NATS provider (when provider: nats)
nats:
url: "nats://localhost:4222"
stream_name: "RESOLVESPEC_EVENTS"
# Database provider (when provider: database)
database:
table_name: "events"
channel: "resolvespec_events"
# Retry policy
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 30s
backoff_factor: 2.0
`

View File

@ -0,0 +1,56 @@
package eventbroker
import (
"fmt"
"os"
"time"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates a provider based on configuration
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
switch cfg.Provider {
case "memory":
cleanupInterval := 5 * time.Minute
if cfg.Database.PollInterval > 0 {
cleanupInterval = cfg.Database.PollInterval
}
return NewMemoryProvider(MemoryProviderOptions{
InstanceID: getInstanceID(cfg.InstanceID),
MaxEvents: 10000,
CleanupInterval: cleanupInterval,
}), nil
case "redis":
// Redis provider will be implemented in Phase 8
return nil, fmt.Errorf("redis provider not yet implemented")
case "nats":
// NATS provider will be implemented in Phase 9
return nil, fmt.Errorf("nats provider not yet implemented")
case "database":
// Database provider will be implemented in Phase 7
return nil, fmt.Errorf("database provider not yet implemented")
default:
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
}
}
// getInstanceID returns the instance ID, defaulting to hostname if not specified
func getInstanceID(configID string) string {
if configID != "" {
return configID
}
// Try to get hostname
if hostname, err := os.Hostname(); err == nil {
return hostname
}
// Fallback to a default
return "resolvespec-instance"
}

View File

@ -0,0 +1,17 @@
package eventbroker
import "context"
// EventHandler processes an event
type EventHandler interface {
Handle(ctx context.Context, event *Event) error
}
// EventHandlerFunc is a function adapter for EventHandler
// This allows using regular functions as event handlers
type EventHandlerFunc func(ctx context.Context, event *Event) error
// Handle implements EventHandler
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
return f(ctx, event)
}

137
pkg/eventbroker/hooks.go Normal file
View File

@ -0,0 +1,137 @@
package eventbroker
import (
"encoding/json"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// CRUDHookConfig configures which CRUD operations should trigger events
type CRUDHookConfig struct {
EnableCreate bool
EnableRead bool
EnableUpdate bool
EnableDelete bool
}
// DefaultCRUDHookConfig returns default configuration (all enabled)
func DefaultCRUDHookConfig() *CRUDHookConfig {
return &CRUDHookConfig{
EnableCreate: true,
EnableRead: false, // Typically disabled for performance
EnableUpdate: true,
EnableDelete: true,
}
}
// RegisterCRUDHooks registers event hooks for CRUD operations
// This integrates with the restheadspec.HookRegistry to automatically
// capture database events
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
if broker == nil {
return fmt.Errorf("broker cannot be nil")
}
if hookRegistry == nil {
return fmt.Errorf("hookRegistry cannot be nil")
}
if config == nil {
config = DefaultCRUDHookConfig()
}
// Create hook handler factory
createHookHandler := func(operation string) restheadspec.HookFunc {
return func(hookCtx *restheadspec.HookContext) error {
// Get user context from Go context
userCtx, ok := security.GetUserContext(hookCtx.Context)
if !ok || userCtx == nil {
logger.Debug("No user context found in hook")
userCtx = &security.UserContext{} // Empty user context
}
// Create event
event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
event.InstanceID = broker.InstanceID()
event.UserID = userCtx.UserID
event.SessionID = userCtx.SessionID
event.Schema = hookCtx.Schema
event.Entity = hookCtx.Entity
event.Operation = operation
// Set payload based on operation
var payload interface{}
switch operation {
case "create":
payload = hookCtx.Result
case "read":
payload = hookCtx.Result
case "update":
payload = map[string]interface{}{
"id": hookCtx.ID,
"data": hookCtx.Data,
}
case "delete":
payload = map[string]interface{}{
"id": hookCtx.ID,
}
}
if payload != nil {
if err := event.SetPayload(payload); err != nil {
logger.Error("Failed to set event payload: %v", err)
payload = map[string]interface{}{"error": "failed to serialize payload"}
event.Payload, _ = json.Marshal(payload)
}
}
// Add metadata
if userCtx.UserName != "" {
event.Metadata["user_name"] = userCtx.UserName
}
if userCtx.Email != "" {
event.Metadata["user_email"] = userCtx.Email
}
if len(userCtx.Roles) > 0 {
event.Metadata["user_roles"] = userCtx.Roles
}
event.Metadata["table_name"] = hookCtx.TableName
// Publish asynchronously to not block CRUD operation
if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
logger.Error("Failed to publish %s event for %s.%s: %v",
operation, hookCtx.Schema, hookCtx.Entity, err)
// Don't fail the CRUD operation if event publishing fails
return nil
}
logger.Debug("Published %s event for %s.%s (ID: %s)",
operation, hookCtx.Schema, hookCtx.Entity, event.ID)
return nil
}
}
// Register hooks based on configuration
if config.EnableCreate {
hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
logger.Info("Registered event hook for CREATE operations")
}
if config.EnableRead {
hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
logger.Info("Registered event hook for READ operations")
}
if config.EnableUpdate {
hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
logger.Info("Registered event hook for UPDATE operations")
}
if config.EnableDelete {
hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
logger.Info("Registered event hook for DELETE operations")
}
return nil
}

View File

@ -0,0 +1,28 @@
package eventbroker
import (
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
// recordEventPublished records an event publication metric
func recordEventPublished(event *Event) {
if mp := metrics.GetProvider(); mp != nil {
mp.RecordEventPublished(string(event.Source), event.Type)
}
}
// recordEventProcessed records an event processing metric
func recordEventProcessed(event *Event, duration time.Duration) {
if mp := metrics.GetProvider(); mp != nil {
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
}
}
// updateQueueSize updates the event queue size metric
func updateQueueSize(size int64) {
if mp := metrics.GetProvider(); mp != nil {
mp.UpdateEventQueueSize(size)
}
}

View File

@ -0,0 +1,70 @@
package eventbroker
import (
"context"
"time"
)
// Provider defines the storage backend interface for events
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider
type Provider interface {
// Store stores an event
Store(ctx context.Context, event *Event) error
// Get retrieves an event by ID
Get(ctx context.Context, id string) (*Event, error)
// List lists events with optional filters
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
// UpdateStatus updates the status of an event
UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error
// Delete deletes an event by ID
Delete(ctx context.Context, id string) error
// Stream returns a channel of events for real-time consumption
// Used for cross-instance pub/sub
// The channel is closed when the context is canceled or an error occurs
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
// Publish publishes an event to all subscribers (for distributed providers)
// For in-memory provider, this is the same as Store
// For Redis/NATS/Database, this triggers cross-instance delivery
Publish(ctx context.Context, event *Event) error
// Close closes the provider and releases resources
Close() error
// Stats returns provider statistics
Stats(ctx context.Context) (*ProviderStats, error)
}
// EventFilter defines filter criteria for listing events
type EventFilter struct {
Source *EventSource
Status *EventStatus
UserID *int
Schema string
Entity string
Operation string
InstanceID string
StartTime *time.Time
EndTime *time.Time
Limit int
Offset int
}
// ProviderStats contains statistics about the provider
type ProviderStats struct {
ProviderType string `json:"provider_type"`
TotalEvents int64 `json:"total_events"`
PendingEvents int64 `json:"pending_events"`
ProcessingEvents int64 `json:"processing_events"`
CompletedEvents int64 `json:"completed_events"`
FailedEvents int64 `json:"failed_events"`
EventsPublished int64 `json:"events_published"`
EventsConsumed int64 `json:"events_consumed"`
ActiveSubscribers int `json:"active_subscribers"`
ProviderSpecific map[string]interface{} `json:"provider_specific,omitempty"`
}

View File

@ -0,0 +1,446 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// MemoryProvider implements Provider interface using in-memory storage
// Features:
// - Thread-safe event storage with RW mutex
// - LRU eviction when max events reached
// - In-process pub/sub (not cross-instance)
// - Automatic cleanup of old completed events
type MemoryProvider struct {
mu sync.RWMutex
events map[string]*Event
eventOrder []string // For LRU tracking
subscribers map[string][]chan *Event
instanceID string
maxEvents int
cleanupInterval time.Duration
maxAge time.Duration
// Statistics
stats MemoryProviderStats
// Lifecycle
stopCleanup chan struct{}
wg sync.WaitGroup
isRunning atomic.Bool
}
// MemoryProviderStats contains statistics for the memory provider
type MemoryProviderStats struct {
TotalEvents atomic.Int64
PendingEvents atomic.Int64
ProcessingEvents atomic.Int64
CompletedEvents atomic.Int64
FailedEvents atomic.Int64
EventsPublished atomic.Int64
EventsConsumed atomic.Int64
ActiveSubscribers atomic.Int32
Evictions atomic.Int64
}
// MemoryProviderOptions configures the memory provider
type MemoryProviderOptions struct {
InstanceID string
MaxEvents int
CleanupInterval time.Duration
MaxAge time.Duration
}
// NewMemoryProvider creates a new in-memory event provider
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
if opts.MaxEvents == 0 {
opts.MaxEvents = 10000 // Default
}
if opts.CleanupInterval == 0 {
opts.CleanupInterval = 5 * time.Minute // Default
}
if opts.MaxAge == 0 {
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
}
mp := &MemoryProvider{
events: make(map[string]*Event),
eventOrder: make([]string, 0),
subscribers: make(map[string][]chan *Event),
instanceID: opts.InstanceID,
maxEvents: opts.MaxEvents,
cleanupInterval: opts.CleanupInterval,
maxAge: opts.MaxAge,
stopCleanup: make(chan struct{}),
}
mp.isRunning.Store(true)
// Start cleanup goroutine
mp.wg.Add(1)
go mp.cleanupLoop()
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
return mp
}
// Store stores an event
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
mp.mu.Lock()
defer mp.mu.Unlock()
// Check if we need to evict oldest events
if len(mp.events) >= mp.maxEvents {
mp.evictOldestLocked()
}
// Store event
mp.events[event.ID] = event.Clone()
mp.eventOrder = append(mp.eventOrder, event.ID)
// Update statistics
mp.stats.TotalEvents.Add(1)
mp.updateStatusCountsLocked(event.Status, 1)
return nil
}
// Get retrieves an event by ID
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
event, exists := mp.events[id]
if !exists {
return nil, fmt.Errorf("event not found: %s", id)
}
return event.Clone(), nil
}
// List lists events with optional filters
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
var results []*Event
for _, event := range mp.events {
if mp.matchesFilter(event, filter) {
results = append(results, event.Clone())
}
}
// Apply limit and offset
if filter != nil {
if filter.Offset > 0 && filter.Offset < len(results) {
results = results[filter.Offset:]
}
if filter.Limit > 0 && filter.Limit < len(results) {
results = results[:filter.Limit]
}
}
return results, nil
}
// UpdateStatus updates the status of an event
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
mp.mu.Lock()
defer mp.mu.Unlock()
event, exists := mp.events[id]
if !exists {
return fmt.Errorf("event not found: %s", id)
}
// Update status counts
mp.updateStatusCountsLocked(event.Status, -1)
mp.updateStatusCountsLocked(status, 1)
// Update event
event.Status = status
if errorMsg != "" {
event.Error = errorMsg
}
return nil
}
// Delete deletes an event by ID
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
mp.mu.Lock()
defer mp.mu.Unlock()
event, exists := mp.events[id]
if !exists {
return fmt.Errorf("event not found: %s", id)
}
// Update counts
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
// Delete event
delete(mp.events, id)
// Remove from order tracking
for i, eid := range mp.eventOrder {
if eid == id {
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
break
}
}
return nil
}
// Stream returns a channel of events for real-time consumption
// Note: This is in-process only, not cross-instance
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Create buffered channel for events
ch := make(chan *Event, 100)
// Store subscriber
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
mp.stats.ActiveSubscribers.Add(1)
// Goroutine to clean up on context cancellation
mp.wg.Add(1)
go func() {
defer mp.wg.Done()
<-ctx.Done()
mp.mu.Lock()
defer mp.mu.Unlock()
// Remove subscriber
subs := mp.subscribers[pattern]
for i, subCh := range subs {
if subCh == ch {
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
break
}
}
mp.stats.ActiveSubscribers.Add(-1)
close(ch)
}()
logger.Debug("Stream created for pattern: %s", pattern)
return ch, nil
}
// Publish publishes an event to all subscribers
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
// Store the event first
if err := mp.Store(ctx, event); err != nil {
return err
}
mp.stats.EventsPublished.Add(1)
// Notify subscribers
mp.mu.RLock()
defer mp.mu.RUnlock()
for pattern, channels := range mp.subscribers {
if matchPattern(pattern, event.Type) {
for _, ch := range channels {
select {
case ch <- event.Clone():
mp.stats.EventsConsumed.Add(1)
default:
// Channel full, skip
logger.Warn("Subscriber channel full for pattern: %s", pattern)
}
}
}
}
return nil
}
// Close closes the provider and releases resources
func (mp *MemoryProvider) Close() error {
if !mp.isRunning.Load() {
return nil
}
mp.isRunning.Store(false)
// Stop cleanup loop
close(mp.stopCleanup)
// Wait for goroutines
mp.wg.Wait()
// Close all subscriber channels
mp.mu.Lock()
for _, channels := range mp.subscribers {
for _, ch := range channels {
close(ch)
}
}
mp.subscribers = make(map[string][]chan *Event)
mp.mu.Unlock()
logger.Info("Memory provider closed")
return nil
}
// Stats returns provider statistics
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
return &ProviderStats{
ProviderType: "memory",
TotalEvents: mp.stats.TotalEvents.Load(),
PendingEvents: mp.stats.PendingEvents.Load(),
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
CompletedEvents: mp.stats.CompletedEvents.Load(),
FailedEvents: mp.stats.FailedEvents.Load(),
EventsPublished: mp.stats.EventsPublished.Load(),
EventsConsumed: mp.stats.EventsConsumed.Load(),
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
ProviderSpecific: map[string]interface{}{
"max_events": mp.maxEvents,
"cleanup_interval": mp.cleanupInterval.String(),
"max_age": mp.maxAge.String(),
"evictions": mp.stats.Evictions.Load(),
},
}, nil
}
// cleanupLoop periodically cleans up old completed events
func (mp *MemoryProvider) cleanupLoop() {
defer mp.wg.Done()
ticker := time.NewTicker(mp.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mp.cleanup()
case <-mp.stopCleanup:
return
}
}
}
// cleanup removes old completed/failed events
func (mp *MemoryProvider) cleanup() {
mp.mu.Lock()
defer mp.mu.Unlock()
cutoff := time.Now().Add(-mp.maxAge)
removed := 0
for id, event := range mp.events {
// Only clean up completed or failed events that are old
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
event.CreatedAt.Before(cutoff) {
delete(mp.events, id)
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
// Remove from order tracking
for i, eid := range mp.eventOrder {
if eid == id {
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
break
}
}
removed++
}
}
if removed > 0 {
logger.Debug("Cleanup removed %d old events", removed)
}
}
// evictOldestLocked evicts the oldest event (LRU)
// Caller must hold write lock
func (mp *MemoryProvider) evictOldestLocked() {
if len(mp.eventOrder) == 0 {
return
}
// Get oldest event ID
oldestID := mp.eventOrder[0]
mp.eventOrder = mp.eventOrder[1:]
// Remove event
if event, exists := mp.events[oldestID]; exists {
delete(mp.events, oldestID)
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
mp.stats.Evictions.Add(1)
logger.Debug("Evicted oldest event: %s", oldestID)
}
}
// matchesFilter checks if an event matches the filter criteria
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
if filter == nil {
return true
}
if filter.Source != nil && event.Source != *filter.Source {
return false
}
if filter.Status != nil && event.Status != *filter.Status {
return false
}
if filter.UserID != nil && event.UserID != *filter.UserID {
return false
}
if filter.Schema != "" && event.Schema != filter.Schema {
return false
}
if filter.Entity != "" && event.Entity != filter.Entity {
return false
}
if filter.Operation != "" && event.Operation != filter.Operation {
return false
}
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
return false
}
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
return false
}
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
return false
}
return true
}
// updateStatusCountsLocked updates status statistics
// Caller must hold write lock
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
switch status {
case EventStatusPending:
mp.stats.PendingEvents.Add(delta)
case EventStatusProcessing:
mp.stats.ProcessingEvents.Add(delta)
case EventStatusCompleted:
mp.stats.CompletedEvents.Add(delta)
case EventStatusFailed:
mp.stats.FailedEvents.Add(delta)
}
}

View File

@ -0,0 +1,419 @@
package eventbroker
import (
"context"
"testing"
"time"
)
func TestNewMemoryProvider(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 100,
CleanupInterval: 1 * time.Minute,
})
if provider == nil {
t.Fatal("Expected non-nil provider")
}
stats, err := provider.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.ProviderType != "memory" {
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
}
}
func TestMemoryProviderPublishAndGet(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
event.UserID = 123
// Publish event
if err := provider.Publish(context.Background(), event); err != nil {
t.Fatalf("Publish failed: %v", err)
}
// Get event
retrieved, err := provider.Get(context.Background(), event.ID)
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if retrieved.ID != event.ID {
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
}
if retrieved.UserID != 123 {
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
}
}
func TestMemoryProviderGetNonExistent(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
_, err := provider.Get(context.Background(), "non-existent-id")
if err == nil {
t.Error("Expected error when getting non-existent event")
}
}
func TestMemoryProviderUpdateStatus(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
// Update status to processing
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
if err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
retrieved, _ := provider.Get(context.Background(), event.ID)
if retrieved.Status != EventStatusProcessing {
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
}
// Update status to failed with error
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
if err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
retrieved, _ = provider.Get(context.Background(), event.ID)
if retrieved.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
}
if retrieved.Error != "test error" {
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
}
}
func TestMemoryProviderList(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish multiple events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
}
// List all events
events, err := provider.List(context.Background(), &EventFilter{})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 5 {
t.Errorf("Expected 5 events, got %d", len(events))
}
}
func TestMemoryProviderListWithFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events with different types
event1 := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event1)
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
provider.Publish(context.Background(), event2)
event3 := NewEvent(EventSourceWebSocket, "chat.message")
provider.Publish(context.Background(), event3)
// Filter by source
source := EventSourceDatabase
events, err := provider.List(context.Background(), &EventFilter{
Source: &source,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 2 {
t.Errorf("Expected 2 events with database source, got %d", len(events))
}
// Filter by status
status := EventStatusPending
events, err = provider.List(context.Background(), &EventFilter{
Status: &status,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 3 {
t.Errorf("Expected 3 events with pending status, got %d", len(events))
}
}
func TestMemoryProviderListWithLimit(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish multiple events
for i := 0; i < 10; i++ {
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}
// List with limit
events, err := provider.List(context.Background(), &EventFilter{
Limit: 5,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 5 {
t.Errorf("Expected 5 events (limited), got %d", len(events))
}
}
func TestMemoryProviderDelete(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
// Delete event
err := provider.Delete(context.Background(), event.ID)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
// Verify deleted
_, err = provider.Get(context.Background(), event.ID)
if err == nil {
t.Error("Expected error when getting deleted event")
}
}
func TestMemoryProviderLRUEviction(t *testing.T) {
// Create provider with small max events
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 3,
})
// Publish 5 events
events := make([]*Event, 5)
for i := 0; i < 5; i++ {
events[i] = NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), events[i])
}
// First 2 events should be evicted
_, err := provider.Get(context.Background(), events[0].ID)
if err == nil {
t.Error("Expected first event to be evicted")
}
_, err = provider.Get(context.Background(), events[1].ID)
if err == nil {
t.Error("Expected second event to be evicted")
}
// Last 3 events should still exist
for i := 2; i < 5; i++ {
_, err := provider.Get(context.Background(), events[i].ID)
if err != nil {
t.Errorf("Expected event %d to still exist", i)
}
}
}
func TestMemoryProviderCleanup(t *testing.T) {
// Create provider with short cleanup interval
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
CleanupInterval: 100 * time.Millisecond,
MaxAge: 200 * time.Millisecond,
})
// Publish and complete an event
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
// Wait for cleanup to run
time.Sleep(400 * time.Millisecond)
// Event should be cleaned up
_, err := provider.Get(context.Background(), event.ID)
if err == nil {
t.Error("Expected event to be cleaned up")
}
provider.Close()
}
func TestMemoryProviderStats(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 100,
})
// Publish events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}
stats, err := provider.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.ProviderType != "memory" {
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
}
if stats.TotalEvents != 5 {
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
}
}
func TestMemoryProviderClose(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
CleanupInterval: 100 * time.Millisecond,
})
// Publish event
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
// Close provider
err := provider.Close()
if err != nil {
t.Fatalf("Close failed: %v", err)
}
// Cleanup goroutine should be stopped
time.Sleep(200 * time.Millisecond)
}
func TestMemoryProviderConcurrency(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Concurrent publish
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
defer func() { done <- true }()
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}()
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Verify all events were stored
events, _ := provider.List(context.Background(), &EventFilter{})
if len(events) != 10 {
t.Errorf("Expected 10 events, got %d", len(events))
}
}
func TestMemoryProviderStream(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Stream is implemented for memory provider (in-process pub/sub)
ch, err := provider.Stream(context.Background(), "test.*")
if err != nil {
t.Fatalf("Stream failed: %v", err)
}
if ch == nil {
t.Error("Expected non-nil channel")
}
}
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events at different times
event1 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event1)
time.Sleep(10 * time.Millisecond)
event2 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event2)
time.Sleep(10 * time.Millisecond)
event3 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event3)
// Filter by time range
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
events, err := provider.List(context.Background(), &EventFilter{
StartTime: &startTime,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
// Should get events 2 and 3
if len(events) != 2 {
t.Errorf("Expected 2 events after start time, got %d", len(events))
}
}
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events with different instance IDs
event1 := NewEvent(EventSourceDatabase, "test.event")
event1.InstanceID = "instance-1"
provider.Publish(context.Background(), event1)
event2 := NewEvent(EventSourceDatabase, "test.event")
event2.InstanceID = "instance-2"
provider.Publish(context.Background(), event2)
// Filter by instance ID
events, err := provider.List(context.Background(), &EventFilter{
InstanceID: "instance-1",
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 1 {
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
}
if events[0].InstanceID != "instance-1" {
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
}
}

View File

@ -0,0 +1,140 @@
package eventbroker
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// SubscriptionID uniquely identifies a subscription
type SubscriptionID string
// subscription represents a single subscription with its handler and pattern
type subscription struct {
id SubscriptionID
pattern string
handler EventHandler
}
// subscriptionManager manages event subscriptions and pattern matching
type subscriptionManager struct {
mu sync.RWMutex
subscriptions map[SubscriptionID]*subscription
nextID atomic.Uint64
}
// newSubscriptionManager creates a new subscription manager
func newSubscriptionManager() *subscriptionManager {
return &subscriptionManager{
subscriptions: make(map[SubscriptionID]*subscription),
}
}
// Subscribe adds a new subscription
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
if pattern == "" {
return "", fmt.Errorf("pattern cannot be empty")
}
if handler == nil {
return "", fmt.Errorf("handler cannot be nil")
}
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
sm.mu.Lock()
sm.subscriptions[id] = &subscription{
id: id,
pattern: pattern,
handler: handler,
}
sm.mu.Unlock()
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
return id, nil
}
// Unsubscribe removes a subscription
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
sm.mu.Lock()
defer sm.mu.Unlock()
if _, exists := sm.subscriptions[id]; !exists {
return fmt.Errorf("subscription not found: %s", id)
}
delete(sm.subscriptions, id)
logger.Info("Unsubscribed: %s", id)
return nil
}
// GetMatching returns all handlers that match the event type
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
sm.mu.RLock()
defer sm.mu.RUnlock()
var handlers []EventHandler
for _, sub := range sm.subscriptions {
if matchPattern(sub.pattern, eventType) {
handlers = append(handlers, sub.handler)
}
}
return handlers
}
// Count returns the number of active subscriptions
func (sm *subscriptionManager) Count() int {
sm.mu.RLock()
defer sm.mu.RUnlock()
return len(sm.subscriptions)
}
// Clear removes all subscriptions
func (sm *subscriptionManager) Clear() {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.subscriptions = make(map[SubscriptionID]*subscription)
logger.Info("Cleared all subscriptions")
}
// matchPattern implements glob-style pattern matching for event types
// Patterns:
// - "*" matches any single segment
// - "a.b.c" matches exactly "a.b.c"
// - "a.*.c" matches "a.anything.c"
// - "a.b.*" matches any operation on a.b
// - "*" matches everything
//
// Event type format: schema.entity.operation (e.g., "public.users.create")
func matchPattern(pattern, eventType string) bool {
// Wildcard matches everything
if pattern == "*" {
return true
}
// Exact match
if pattern == eventType {
return true
}
// Split pattern and event type by dots
patternParts := strings.Split(pattern, ".")
eventParts := strings.Split(eventType, ".")
// Different number of parts can only match if pattern has wildcards
if len(patternParts) != len(eventParts) {
return false
}
// Match each part
for i := range patternParts {
if patternParts[i] != "*" && patternParts[i] != eventParts[i] {
return false
}
}
return true
}

View File

@ -0,0 +1,270 @@
package eventbroker
import (
"context"
"testing"
)
func TestMatchPattern(t *testing.T) {
tests := []struct {
pattern string
eventType string
expected bool
}{
// Exact matches
{"public.users.create", "public.users.create", true},
{"public.users.create", "public.users.update", false},
// Wildcard matches
{"*", "public.users.create", true},
{"*", "anything", true},
{"public.*", "public.users", true},
{"public.*", "public.users.create", false}, // Different number of parts
{"public.*", "admin.users", false},
{"*.users.create", "public.users.create", true},
{"*.users.create", "admin.users.create", true},
{"*.users.create", "public.roles.create", false},
{"public.*.create", "public.users.create", true},
{"public.*.create", "public.roles.create", true},
{"public.*.create", "public.users.update", false},
// Multiple wildcards
{"*.*", "public.users", true},
{"*.*", "public.users.create", false}, // Different number of parts
{"*.*.create", "public.users.create", true},
{"*.*.create", "admin.roles.create", true},
{"*.*.create", "public.users.update", false},
// Edge cases
{"", "", true},
{"", "something", false},
{"something", "", false},
}
for _, tt := range tests {
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
result := matchPattern(tt.pattern, tt.eventType)
if result != tt.expected {
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
tt.pattern, tt.eventType, result, tt.expected)
}
})
}
}
func TestSubscriptionManager(t *testing.T) {
manager := newSubscriptionManager()
// Create test handler
called := false
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
})
// Test Subscribe
id, err := manager.Subscribe("public.users.*", handler)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
if id == "" {
t.Fatal("Expected non-empty subscription ID")
}
// Test GetMatching
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 1 {
t.Fatalf("Expected 1 handler, got %d", len(handlers))
}
// Test handler execution
event := NewEvent(EventSourceDatabase, "public.users.create")
if err := handlers[0].Handle(context.Background(), event); err != nil {
t.Fatalf("Handler execution failed: %v", err)
}
if !called {
t.Error("Expected handler to be called")
}
// Test Count
if manager.Count() != 1 {
t.Errorf("Expected count 1, got %d", manager.Count())
}
// Test Unsubscribe
if err := manager.Unsubscribe(id); err != nil {
t.Fatalf("Unsubscribe failed: %v", err)
}
// Verify unsubscribed
handlers = manager.GetMatching("public.users.create")
if len(handlers) != 0 {
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
}
if manager.Count() != 0 {
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
}
}
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
manager := newSubscriptionManager()
called1 := false
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called1 = true
return nil
})
called2 := false
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called2 = true
return nil
})
// Subscribe multiple handlers
id1, _ := manager.Subscribe("public.users.*", handler1)
id2, _ := manager.Subscribe("*.users.*", handler2)
// Both should match
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 2 {
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
}
// Execute all handlers
event := NewEvent(EventSourceDatabase, "public.users.create")
for _, h := range handlers {
h.Handle(context.Background(), event)
}
if !called1 || !called2 {
t.Error("Expected both handlers to be called")
}
// Unsubscribe one
manager.Unsubscribe(id1)
handlers = manager.GetMatching("public.users.create")
if len(handlers) != 1 {
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
}
// Unsubscribe remaining
manager.Unsubscribe(id2)
if manager.Count() != 0 {
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
}
}
func TestSubscriptionManagerConcurrency(t *testing.T) {
manager := newSubscriptionManager()
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
})
// Subscribe and unsubscribe concurrently
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
defer func() { done <- true }()
id, _ := manager.Subscribe("test.*", handler)
manager.GetMatching("test.event")
manager.Unsubscribe(id)
}()
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Should have no subscriptions left
if manager.Count() != 0 {
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
}
}
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
manager := newSubscriptionManager()
// Try to unsubscribe a non-existent ID
err := manager.Unsubscribe("non-existent-id")
if err == nil {
t.Error("Expected error when unsubscribing non-existent ID")
}
}
func TestSubscriptionIDGeneration(t *testing.T) {
manager := newSubscriptionManager()
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
})
// Subscribe multiple times and ensure unique IDs
ids := make(map[SubscriptionID]bool)
for i := 0; i < 100; i++ {
id, _ := manager.Subscribe("test.*", handler)
if ids[id] {
t.Fatalf("Duplicate subscription ID: %s", id)
}
ids[id] = true
}
}
func TestEventHandlerFunc(t *testing.T) {
called := false
var receivedEvent *Event
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
receivedEvent = event
return nil
})
event := NewEvent(EventSourceDatabase, "test.event")
err := handler.Handle(context.Background(), event)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !called {
t.Error("Expected handler to be called")
}
if receivedEvent != event {
t.Error("Expected to receive the same event")
}
}
func TestSubscriptionManagerPatternPriority(t *testing.T) {
manager := newSubscriptionManager()
// More specific patterns should still match
specificCalled := false
genericCalled := false
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
specificCalled = true
return nil
}))
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
genericCalled = true
return nil
}))
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 2 {
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
}
// Execute all handlers
event := NewEvent(EventSourceDatabase, "public.users.create")
for _, h := range handlers {
h.Handle(context.Background(), event)
}
if !specificCalled || !genericCalled {
t.Error("Expected both specific and generic handlers to be called")
}
}

View File

@ -0,0 +1,141 @@
package eventbroker
import (
"context"
"sync"
"sync/atomic"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// workerPool manages a pool of workers for async event processing
type workerPool struct {
workerCount int
bufferSize int
eventQueue chan *Event
processor func(context.Context, *Event) error
activeWorkers atomic.Int32
isRunning atomic.Bool
stopCh chan struct{}
wg sync.WaitGroup
}
// newWorkerPool creates a new worker pool
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
return &workerPool{
workerCount: workerCount,
bufferSize: bufferSize,
eventQueue: make(chan *Event, bufferSize),
processor: processor,
stopCh: make(chan struct{}),
}
}
// Start starts the worker pool
func (wp *workerPool) Start() {
if wp.isRunning.Load() {
return
}
wp.isRunning.Store(true)
// Start workers
for i := 0; i < wp.workerCount; i++ {
wp.wg.Add(1)
go wp.worker(i)
}
logger.Info("Worker pool started with %d workers", wp.workerCount)
}
// Stop stops the worker pool gracefully
func (wp *workerPool) Stop(ctx context.Context) error {
if !wp.isRunning.Load() {
return nil
}
wp.isRunning.Store(false)
// Close event queue to signal workers
close(wp.eventQueue)
// Wait for workers to finish with context timeout
done := make(chan struct{})
go func() {
wp.wg.Wait()
close(done)
}()
select {
case <-done:
logger.Info("Worker pool stopped gracefully")
return nil
case <-ctx.Done():
logger.Warn("Worker pool stop timed out, some events may be lost")
return ctx.Err()
}
}
// Submit submits an event to the queue
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
if !wp.isRunning.Load() {
return ErrWorkerPoolStopped
}
select {
case wp.eventQueue <- event:
return nil
case <-ctx.Done():
return ctx.Err()
default:
return ErrQueueFull
}
}
// worker is a worker goroutine that processes events from the queue
func (wp *workerPool) worker(id int) {
defer wp.wg.Done()
logger.Debug("Worker %d started", id)
for event := range wp.eventQueue {
wp.activeWorkers.Add(1)
// Process event with background context (detached from original request)
ctx := context.Background()
if err := wp.processor(ctx, event); err != nil {
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
}
wp.activeWorkers.Add(-1)
}
logger.Debug("Worker %d stopped", id)
}
// QueueSize returns the current queue size
func (wp *workerPool) QueueSize() int {
return len(wp.eventQueue)
}
// ActiveWorkers returns the number of currently active workers
func (wp *workerPool) ActiveWorkers() int {
return int(wp.activeWorkers.Load())
}
// Error definitions
var (
ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
)
// BrokerError represents an error from the event broker
type BrokerError struct {
Code string
Message string
}
func (e *BrokerError) Error() string {
return e.Message
}

1056
pkg/funcspec/function_api.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,910 @@
package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
return nil
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
return nil
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
return nil
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
return nil
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
if m.ExecFunc != nil {
return m.ExecFunc(ctx, query, args...)
}
return &MockResult{rows: 0}, nil
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if m.QueryFunc != nil {
return m.QueryFunc(ctx, dest, query, args...)
}
return nil
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
return m, nil
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
if m.RunInTransactionFunc != nil {
return m.RunInTransactionFunc(ctx, fn)
}
return fn(m)
}
func (m *MockDatabase) GetUnderlyingDB() interface{} {
return m
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64
id int64
}
func (m *MockResult) RowsAffected() int64 {
return m.rows
}
func (m *MockResult) LastInsertId() (int64, error) {
return m.id, nil
}
// Helper function to create a test request with user context
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
u, _ := url.Parse(path)
if queryParams != nil {
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
var bodyReader *bytes.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
} else {
bodyReader = bytes.NewReader([]byte{})
}
req := httptest.NewRequest(method, u.String(), bodyReader)
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
// Add user context
userCtx := &security.UserContext{
UserID: 1,
UserName: "testuser",
SessionID: "test-session-123",
}
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
req = req.WithContext(ctx)
return req
}
// TestNewHandler tests handler creation
func TestNewHandler(t *testing.T) {
db := &MockDatabase{}
handler := NewHandler(db)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.db != db {
t.Error("Expected handler to have the provided database")
}
if handler.hooks == nil {
t.Error("Expected handler to have a hook registry")
}
}
// TestHandlerHooks tests the Hooks method
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(&MockDatabase{})
hooks := handler.Hooks()
if hooks == nil {
t.Fatal("Expected hooks registry to be non-nil")
}
// Should return the same instance
hooks2 := handler.Hooks()
if hooks != hooks2 {
t.Error("Expected Hooks() to return the same registry instance")
}
}
// TestExtractInputVariables tests the extractInputVariables function
func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
sqlQuery: "SELECT * FROM users",
expectedVars: []string{},
},
{
name: "Single variable",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
expectedVars: []string{"[user_id]"},
},
{
name: "Multiple variables",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
expectedVars: []string{"[user_id]", "[user_name]"},
},
{
name: "Nested brackets",
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
expectedVars: []string{"[field]", "[user_id]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputvars := make([]string, 0)
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
if result != tt.sqlQuery {
t.Errorf("Expected SQL query to be unchanged, got %s", result)
}
if len(inputvars) != len(tt.expectedVars) {
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
return
}
for i, expected := range tt.expectedVars {
if inputvars[i] != expected {
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
}
}
})
}
}
// TestValidSQL tests the SQL sanitization function
func TestValidSQL(t *testing.T) {
tests := []struct {
name string
input string
mode string
expected string
}{
{
name: "Column name with valid characters",
input: "user_id",
mode: "colname",
expected: "user_id",
},
{
name: "Column name with dots (table.column)",
input: "users.user_id",
mode: "colname",
expected: "users.user_id",
},
{
name: "Column name with SQL injection attempt",
input: "id'; DROP TABLE users--",
mode: "colname",
expected: "idDROPTABLEusers",
},
{
name: "Column value with single quotes",
input: "O'Brien",
mode: "colvalue",
expected: "O''Brien",
},
{
name: "Select with dangerous keywords",
input: "name, email; DROP TABLE users",
mode: "select",
expected: "name, email TABLE users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ValidSQL(tt.input, tt.mode)
if result != tt.expected {
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
}
})
}
}
// TestIsNumeric tests the IsNumeric function
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"123", true},
{"123.45", true},
{"-123", true},
{"-123.45", true},
{"0", true},
{"abc", false},
{"12.34.56", false},
{"", false},
{"123abc", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := IsNumeric(tt.input)
if result != tt.expected {
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
// TestSqlQryWhere tests the WHERE clause manipulation
func TestSqlQryWhere(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
expected string
}{
{
name: "Add WHERE to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ",
},
{
name: "Add AND to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
},
{
name: "Add WHERE before ORDER BY",
sqlQuery: "SELECT * FROM users ORDER BY name",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
},
{
name: "Add WHERE before GROUP BY",
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
condition: "status = 'active'",
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
},
{
name: "Add WHERE before LIMIT",
sqlQuery: "SELECT * FROM users LIMIT 10",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhere(tt.sqlQuery, tt.condition)
if result != tt.expected {
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
return req
},
expected: "192.168.1.100",
},
{
name: "X-Real-IP header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
return req
},
expected: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
return req
},
expected: "192.168.1.1:12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupReq()
result := getIPAddress(req)
if result != tt.expected {
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestParsePaginationParams tests pagination parameter parsing
func TestParsePaginationParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
expectedSort string
expectedLimit int
expectedOffset int
}{
{
name: "No parameters - defaults",
queryParams: map[string]string{},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "All parameters provided",
queryParams: map[string]string{
"sort": "name,-created_at",
"limit": "100",
"offset": "50",
},
expectedSort: "name,-created_at",
expectedLimit: 100,
expectedOffset: 50,
},
{
name: "Invalid limit - use default",
queryParams: map[string]string{
"limit": "invalid",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "Negative offset - use default",
queryParams: map[string]string{
"offset": "-10",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
sort, limit, offset := handler.parsePaginationParams(req)
if sort != tt.expectedSort {
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
}
if limit != tt.expectedLimit {
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
}
if offset != tt.expectedOffset {
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
}
})
}
}
// TestSqlQuery tests the SqlQuery handler for single record queries
func TestSqlQuery(t *testing.T) {
tests := []struct {
name string
sqlQuery string
blankParams bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, body []byte)
}{
{
name: "Basic query - returns single record",
sqlQuery: "SELECT * FROM users WHERE id = 1",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if result["name"] != "Test User" {
t.Errorf("Expected name='Test User', got %v", result["name"])
}
},
},
{
name: "Query with no results",
sqlQuery: "SELECT * FROM users WHERE id = 999",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
// Return empty array
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 0 {
t.Errorf("Expected empty result, got %v", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams})
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
if tt.validateResp != nil {
tt.validateResp(t, w.Body.Bytes())
}
})
}
}
// TestSqlQueryList tests the SqlQueryList handler for list queries
func TestSqlQueryList(t *testing.T) {
tests := []struct {
name string
sqlQuery string
noCount bool
blankParams bool
allowFilter bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
}{
{
name: "Basic list query",
sqlQuery: "SELECT * FROM users",
noCount: false,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
callCount := 0
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callCount++
if strings.Contains(query, "COUNT") {
// Count query
countResult := dest.(*struct{ Count int64 })
countResult.Count = 2
} else {
// Main query
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
{"id": float64(2), "name": "User 2"},
}
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 2 {
t.Errorf("Expected 2 results, got %d", len(result))
}
// Check Content-Range header
contentRange := w.Header().Get("Content-Range")
if !strings.Contains(contentRange, "2") {
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
}
},
},
{
name: "List query with noCount",
sqlQuery: "SELECT * FROM users",
noCount: true,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
t.Error("Count query should not be executed when noCount is true")
}
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 1 {
t.Errorf("Expected 1 result, got %d", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
}
if tt.validateResp != nil {
tt.validateResp(t, w)
}
})
}
}
// TestMergeQueryParams tests query parameter merging
func TestMergeQueryParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
queryParams map[string]string
allowFilter bool
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Replace placeholder with parameter",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
queryParams: map[string]string{"p-user_id": "123"},
allowFilter: false,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["p-user_id"] != "123" {
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
}
},
},
{
name: "Add filter when allowed",
sqlQuery: "SELECT * FROM users",
queryParams: map[string]string{"status": "active"},
allowFilter: true,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["status"] != "active" {
t.Errorf("Expected status=active, got %v", vars["status"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestMergeHeaderParams tests header parameter merging
func TestMergeHeaderParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
headers map[string]string
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Field filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-FieldFilter-Status": "1"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-fieldfilter-status"] != "1" {
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
}
},
},
{
name: "Search filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-SearchFilter-Name": "john"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-searchfilter-name"] != "john" {
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
complexAPI := false
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestReplaceMetaVariables tests meta variable replacement
func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "ABC456",
SessionRID: 456,
}
metainfo := map[string]interface{}{
"ipaddress": "192.168.1.1",
"url": "/api/test",
}
variables := map[string]interface{}{
"param1": "value1",
}
tests := []struct {
name string
sqlQuery string
expectedCheck func(result string) bool
}{
{
name: "Replace [rid_user]",
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "123")
},
},
{
name: "Replace [user]",
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'testuser'")
},
},
{
name: "Replace [rid_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "456")
},
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, nil, nil)
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
if !tt.expectedCheck(result) {
t.Errorf("Meta variable replacement failed. Query: %s", result)
}
})
}
}
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
func TestGetReplacementForBlankParam(t *testing.T) {
tests := []struct {
name string
sqlQuery string
param string
expected string
}{
{
name: "Parameter in single quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
param: "[username]",
expected: "",
},
{
name: "Parameter in dollar quotes",
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
param: "[jsondata]",
expected: "",
},
{
name: "Parameter not in quotes",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter not in quotes with AND",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - before quote",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - in quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
param: "[username]",
expected: "",
},
{
name: "Parameter with dollar quote tag",
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
param: "[content]",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
if result != tt.expected {
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
}
})
}
}

160
pkg/funcspec/hooks.go Normal file
View File

@ -0,0 +1,160 @@
package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Query operation hooks (for SqlQuery - single record)
BeforeQuery HookType = "before_query"
AfterQuery HookType = "after_query"
// Query list operation hooks (for SqlQueryList - multiple records)
BeforeQueryList HookType = "before_query_list"
AfterQueryList HookType = "after_query_list"
// SQL execution hooks (just before SQL is executed)
BeforeSQLExec HookType = "before_sql_exec"
AfterSQLExec HookType = "after_sql_exec"
// Response hooks (before response is sent)
BeforeResponse HookType = "before_response"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database
Request *http.Request
Writer http.ResponseWriter
// SQL query and variables
SQLQuery string // The SQL query being executed (can be modified by hooks)
Variables map[string]interface{} // Variables extracted from request
InputVars []string // Input variable placeholders found in query
MetaInfo map[string]interface{} // Metadata about the request
PropQry map[string]string // Property query parameters
// User context
UserContext *security.UserContext
// Pagination and filtering (for list queries)
SortColumns string
Limit int
Offset int
// Results
Result interface{} // Query result (single record or list)
Total int64 // Total count (for list queries)
Error error // Error if operation failed
ComplexAPI bool // Whether complex API response format is requested
NoCount bool // Whether count query should be skipped
BlankParams bool // Whether blank parameters should be removed
AllowFilter bool // Whether filtering is allowed
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all funcspec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all funcspec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@ -0,0 +1,137 @@
package funcspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example hook functions demonstrating various use cases
// ExampleLoggingHook logs all SQL queries before execution
func ExampleLoggingHook(ctx *HookContext) error {
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
return nil
}
// ExampleSecurityHook validates user permissions before executing queries
func ExampleSecurityHook(ctx *HookContext) error {
// Example: Block queries that try to access sensitive tables
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
if ctx.UserContext.UserID != 1 { // Only admin can access
ctx.Abort = true
ctx.AbortCode = 403
ctx.AbortMessage = "Access denied: insufficient permissions"
return fmt.Errorf("access denied to sensitive_table")
}
}
return nil
}
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
func ExampleQueryModificationHook(ctx *HookContext) error {
// Example: Automatically add user_id filter for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
// Add WHERE clause to filter by user_id
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
} else {
ctx.SQLQuery = strings.Replace(
ctx.SQLQuery,
"WHERE",
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
1,
)
}
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
}
return nil
}
// ExampleResultFilterHook filters results after query execution
func ExampleResultFilterHook(ctx *HookContext) error {
// Example: Remove sensitive fields from results for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
switch result := ctx.Result.(type) {
case []map[string]interface{}:
// Filter list results
for i := range result {
delete(result[i], "password")
delete(result[i], "ssn")
delete(result[i], "credit_card")
}
case map[string]interface{}:
// Filter single record
delete(result, "password")
delete(result, "ssn")
delete(result, "credit_card")
}
}
return nil
}
// ExampleAuditHook logs all queries and results for audit purposes
func ExampleAuditHook(ctx *HookContext) error {
// Log to audit table or external system
logger.Info("AUDIT: User %s (%d) executed query from %s",
ctx.UserContext.UserName,
ctx.UserContext.UserID,
ctx.Request.RemoteAddr,
)
// In a real implementation, you might:
// - Insert into an audit log table
// - Send to a logging service
// - Write to a file
return nil
}
// ExampleCacheHook implements simple response caching
func ExampleCacheHook(ctx *HookContext) error {
// This is a simplified example - real caching would use a proper cache store
// Check if we have a cached result for this query
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
// ctx.Result = cachedResult
// ctx.Abort = true // Skip query execution
// ctx.AbortMessage = "Serving from cache"
// }
return nil
}
// ExampleErrorHandlingHook provides custom error handling
func ExampleErrorHandlingHook(ctx *HookContext) error {
if ctx.Error != nil {
// Log error with context
logger.Error("Query failed for user %s: %v\nQuery: %s",
ctx.UserContext.UserName,
ctx.Error,
ctx.SQLQuery,
)
// You could send notifications, update metrics, etc.
}
return nil
}
// Example of registering hooks:
//
// func SetupHooks(handler *Handler) {
// hooks := handler.Hooks()
//
// // Register security hook before query execution
// hooks.Register(BeforeQuery, ExampleSecurityHook)
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
//
// // Register logging hook before SQL execution
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
//
// // Register result filtering after query
// hooks.Register(AfterQuery, ExampleResultFilterHook)
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
//
// // Register audit hook after execution
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
// }

589
pkg/funcspec/hooks_test.go Normal file
View File

@ -0,0 +1,589 @@
package funcspec
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// TestNewHookRegistry tests hook registry creation
func TestNewHookRegistry(t *testing.T) {
registry := NewHookRegistry()
if registry == nil {
t.Fatal("Expected registry to be created, got nil")
}
if registry.hooks == nil {
t.Error("Expected hooks map to be initialized")
}
}
// TestRegisterHook tests registering a single hook
func TestRegisterHook(t *testing.T) {
registry := NewHookRegistry()
hookCalled := false
testHook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
registry.Register(BeforeQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected hook to be registered")
}
if registry.Count(BeforeQuery) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
}
// Execute the hook
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !hookCalled {
t.Error("Expected hook to be called")
}
}
// TestRegisterMultipleHooks tests registering multiple hooks for same type
func TestRegisterMultipleHooks(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
if registry.Count(BeforeQuery) != 3 {
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
}
// Execute hooks
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
// Verify hooks were called in order
if len(callOrder) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
}
for i, expected := range []int{1, 2, 3} {
if callOrder[i] != expected {
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
}
}
}
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
func TestRegisterMultipleHookTypes(t *testing.T) {
registry := NewHookRegistry()
callCount := 0
testHook := func(ctx *HookContext) error {
callCount++
return nil
}
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
registry.RegisterMultiple(hookTypes, testHook)
// Verify hook is registered for all types
for _, hookType := range hookTypes {
if !registry.HasHooks(hookType) {
t.Errorf("Expected hook to be registered for %s", hookType)
}
if registry.Count(hookType) != 1 {
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
}
}
// Execute each hook type
ctx := &HookContext{}
for _, hookType := range hookTypes {
if err := registry.Execute(hookType, ctx); err != nil {
t.Errorf("Hook execution failed for %s: %v", hookType, err)
}
}
if callCount != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
expectedError := fmt.Errorf("test error")
errorHook := func(ctx *HookContext) error {
return expectedError
}
registry.Register(BeforeQuery, errorHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook, got nil")
}
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
t.Errorf("Expected error message to contain hook error, got: %v", err)
}
}
// TestHookAbort tests hook abort functionality
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeQuery, abortHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error when hook aborts, got nil")
}
if !ctx.Abort {
t.Error("Expected Abort to be true")
}
if ctx.AbortMessage != "Operation aborted by hook" {
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
}
}
// TestHookChainWithError tests that hook chain stops on first error
func TestHookChainWithError(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return fmt.Errorf("error in hook 2")
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook chain")
}
// Only first two hooks should have been called
if len(callOrder) != 2 {
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
}
if callOrder[0] != 1 || callOrder[1] != 2 {
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hook to be registered")
}
registry.Clear(BeforeQuery)
if registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hooks to be cleared")
}
if !registry.HasHooks(AfterQuery) {
t.Error("Expected AfterQuery hook to still be registered")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
registry.Register(BeforeSQLExec, testHook)
registry.ClearAll()
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
t.Error("Expected all hooks to be cleared")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
types := registry.GetAllHookTypes()
if len(types) != 2 {
t.Errorf("Expected 2 hook types, got %d", len(types))
}
// Verify the types are present
foundBefore := false
foundAfter := false
for _, hookType := range types {
if hookType == BeforeQuery {
foundBefore = true
}
if hookType == AfterQuery {
foundAfter = true
}
}
if !foundBefore || !foundAfter {
t.Error("Expected both BeforeQuery and AfterQuery hook types")
}
}
// TestHookContextModification tests that hooks can modify the context
func TestHookContextModification(t *testing.T) {
registry := NewHookRegistry()
// Hook that modifies SQL query
modifyHook := func(ctx *HookContext) error {
ctx.SQLQuery = "SELECT * FROM modified_table"
ctx.Variables["new_var"] = "new_value"
return nil
}
registry.Register(BeforeQuery, modifyHook)
ctx := &HookContext{
SQLQuery: "SELECT * FROM original_table",
Variables: make(map[string]interface{}),
}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if ctx.SQLQuery != "SELECT * FROM modified_table" {
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
}
if ctx.Variables["new_var"] != "new_value" {
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
}
}
// TestExampleHooks tests the example hooks
func TestExampleLoggingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleLoggingHook(ctx)
if err != nil {
t.Errorf("ExampleLoggingHook failed: %v", err)
}
}
func TestExampleSecurityHook(t *testing.T) {
tests := []struct {
name string
sqlQuery string
userID int
shouldAbort bool
}{
{
name: "Admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 1,
shouldAbort: false,
},
{
name: "Non-admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 2,
shouldAbort: true,
},
{
name: "Non-admin accessing normal table",
sqlQuery: "SELECT * FROM users",
userID: 2,
shouldAbort: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: tt.sqlQuery,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
_ = ExampleSecurityHook(ctx)
if tt.shouldAbort {
if !ctx.Abort {
t.Error("Expected security hook to abort operation")
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
}
} else {
if ctx.Abort {
t.Error("Expected security hook not to abort operation")
}
}
})
}
}
func TestExampleResultFilterHook(t *testing.T) {
tests := []struct {
name string
userID int
result interface{}
validate func(t *testing.T, result interface{})
}{
{
name: "Admin user - no filtering",
userID: 1,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; !exists {
t.Error("Expected password field to remain for admin")
}
},
},
{
name: "Regular user - sensitive fields removed",
userID: 2,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
"ssn": "123-45-6789",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed")
}
if _, exists := m["ssn"]; exists {
t.Error("Expected ssn field to be removed")
}
if _, exists := m["name"]; !exists {
t.Error("Expected name field to remain")
}
},
},
{
name: "Regular user - list results filtered",
userID: 2,
result: []map[string]interface{}{
{"id": 1, "name": "User 1", "password": "secret1"},
{"id": 2, "name": "User 2", "password": "secret2"},
},
validate: func(t *testing.T, result interface{}) {
list := result.([]map[string]interface{})
for _, m := range list {
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed from list")
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Result: tt.result,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
err := ExampleResultFilterHook(ctx)
if err != nil {
t.Errorf("Hook failed: %v", err)
}
if tt.validate != nil {
tt.validate(t, ctx.Result)
}
})
}
}
func TestExampleAuditHook(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
ctx := &HookContext{
Context: context.Background(),
Request: req,
UserContext: &security.UserContext{
UserID: 123,
UserName: "testuser",
},
}
err := ExampleAuditHook(ctx)
if err != nil {
t.Errorf("ExampleAuditHook failed: %v", err)
}
}
func TestExampleErrorHandlingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
Error: fmt.Errorf("test error"),
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleErrorHandlingHook(ctx)
if err != nil {
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
}
}
// TestHookIntegrationWithHandler tests hooks integrated with the handler
func TestHookIntegrationWithHandler(t *testing.T) {
db := &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
queryDB := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User"},
}
return nil
},
}
return fn(queryDB)
},
}
handler := NewHandler(db)
// Register a hook that modifies the SQL query
hookCalled := false
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
hookCalled = true
// Verify we can access context data
if ctx.SQLQuery == "" {
t.Error("Expected SQL query to be set")
}
if ctx.UserContext == nil {
t.Error("Expected user context to be set")
}
return nil
})
// Execute a query
req := createTestRequest("GET", "/test", nil, nil, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{})
handlerFunc(w, req)
if !hookCalled {
t.Error("Expected hook to be called during query execution")
}
if w.Code != 200 {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

411
pkg/funcspec/parameters.go Normal file
View File

@ -0,0 +1,411 @@
package funcspec
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RequestParameters holds parsed parameters from headers and query string
type RequestParameters struct {
// Field selection
SelectFields []string
NotSelectFields []string
Distinct bool
// Filtering
FieldFilters map[string]string // column -> value (exact match)
SearchFilters map[string]string // column -> value (ILIKE)
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
CustomSQLWhere string
CustomSQLOr string
// Sorting & Pagination
SortColumns string
Limit int
Offset int
// Advanced features
SkipCount bool
SkipCache bool
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
ComplexAPI bool // true if NOT simple API
}
// FilterOperator represents a filter with operator
type FilterOperator struct {
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
Value string
Logic string // AND or OR
}
// ParseParameters parses all parameters from request headers and query string
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
params := &RequestParameters{
FieldFilters: make(map[string]string),
SearchFilters: make(map[string]string),
SearchOps: make(map[string]FilterOperator),
Limit: 20, // Default limit
Offset: 0, // Default offset
ResponseFormat: "simple", // Default format
ComplexAPI: false, // Default to simple API
}
// Merge headers and query parameters
combined := make(map[string]string)
// Add all headers (normalize to lowercase)
for key, values := range r.Header {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Add all query parameters (override headers)
for key, values := range r.URL.Query() {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Parse each parameter
for key, value := range combined {
// Decode value if base64 encoded
decodedValue := h.decodeValue(value)
switch {
// Field Selection
case strings.HasPrefix(key, "x-select-fields"):
params.SelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-not-select-fields"):
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-distinct"):
params.Distinct = strings.EqualFold(decodedValue, "true")
// Filtering
case strings.HasPrefix(key, "x-fieldfilter-"):
colName := strings.TrimPrefix(key, "x-fieldfilter-")
params.FieldFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchfilter-"):
colName := strings.TrimPrefix(key, "x-searchfilter-")
params.SearchFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(params, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-custom-sql-w"):
if params.CustomSQLWhere != "" {
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
} else {
params.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if params.CustomSQLOr != "" {
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
} else {
params.CustomSQLOr = decodedValue
}
// Sorting & Pagination
case key == "sort" || strings.HasPrefix(key, "x-sort"):
params.SortColumns = decodedValue
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
// Handle sort(col1,-col2) syntax
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
params.SortColumns = sortValue
case key == "limit" || strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
params.Limit = limit
}
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
// Handle limit(offset,limit) or limit(limit) syntax
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
parts := strings.Split(limitValue, ",")
if len(parts) > 1 {
if offset, err := strconv.Atoi(parts[0]); err == nil {
params.Offset = offset
}
if limit, err := strconv.Atoi(parts[1]); err == nil {
params.Limit = limit
}
} else {
if limit, err := strconv.Atoi(parts[0]); err == nil {
params.Limit = limit
}
}
case key == "offset" || strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
params.Offset = offset
}
// Advanced features
case strings.HasPrefix(key, "x-skipcount"):
params.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcache"):
params.SkipCache = strings.EqualFold(decodedValue, "true")
// Response Format
case strings.HasPrefix(key, "x-simpleapi"):
params.ResponseFormat = "simple"
params.ComplexAPI = decodedValue != "1" && !strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-detailapi"):
params.ResponseFormat = "detail"
params.ComplexAPI = true
case strings.HasPrefix(key, "x-syncfusion"):
params.ResponseFormat = "syncfusion"
params.ComplexAPI = true
}
}
return params
}
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
var prefix string
if logic == "OR" {
prefix = "x-searchor-"
} else {
prefix = "x-searchop-"
if strings.HasPrefix(headerKey, "x-searchand-") {
prefix = "x-searchand-"
}
}
rest := strings.TrimPrefix(headerKey, prefix)
parts := strings.SplitN(rest, "-", 2)
if len(parts) != 2 {
logger.Warn("Invalid search operator header format: %s", headerKey)
return
}
operator := parts[0]
colName := parts[1]
params.SearchOps[colName] = FilterOperator{
Operator: operator,
Value: value,
Logic: logic,
}
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
}
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
func (h *Handler) decodeValue(value string) string {
decoded, _ := restheadspec.DecodeParam(value)
return decoded
}
// parseCommaSeparated parses comma-separated values
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
return result
}
// ApplyFieldSelection applies column selection to SQL query
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
return sqlQuery
}
// This is a simplified implementation
// A full implementation would parse the SQL and replace the SELECT clause
// For now, we log a warning that this feature needs manual implementation
if len(params.SelectFields) > 0 {
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
}
if len(params.NotSelectFields) > 0 {
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
}
return sqlQuery
}
// ApplyFilters applies all filters to the SQL query
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
// Apply field filters (exact match)
for colName, value := range params.FieldFilters {
condition := ""
if value == "" || value == "0" {
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
} else {
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
}
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied field filter: %s", condition)
}
// Apply search filters (ILIKE)
for colName, value := range params.SearchFilters {
sval := strings.ReplaceAll(value, "'", "")
if sval != "" {
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied search filter: %s", condition)
}
}
// Apply search operators
for colName, filterOp := range params.SearchOps {
condition := h.buildFilterCondition(colName, filterOp)
if condition != "" {
if filterOp.Logic == "OR" {
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
} else {
sqlQuery = sqlQryWhere(sqlQuery, condition)
}
logger.Debug("Applied search operator: %s", condition)
}
}
// Apply custom SQL WHERE
if params.CustomSQLWhere != "" {
colval := ValidSQL(params.CustomSQLWhere, "select")
if colval != "" {
sqlQuery = sqlQryWhere(sqlQuery, colval)
logger.Debug("Applied custom SQL WHERE: %s", colval)
}
}
// Apply custom SQL OR
if params.CustomSQLOr != "" {
colval := ValidSQL(params.CustomSQLOr, "select")
if colval != "" {
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
logger.Debug("Applied custom SQL OR: %s", colval)
}
}
return sqlQuery
}
// buildFilterCondition builds a SQL condition from a FilterOperator
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
safCol := ValidSQL(colName, "colname")
operator := strings.ToLower(op.Operator)
value := op.Value
switch operator {
case "contains", "contain", "like":
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
case "beginswith", "startswith":
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
case "endswith":
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
case "equals", "eq", "=":
if IsNumeric(value) {
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
case "notequals", "neq", "ne", "!=", "<>":
if IsNumeric(value) {
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
case "greaterthan", "gt", ">":
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
case "lessthan", "lt", "<":
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
case "greaterthanorequal", "gte", "ge", ">=":
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
case "lessthanorequal", "lte", "le", "<=":
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
case "between":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "betweeninclusive":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "in":
values := strings.Split(value, ",")
safeValues := make([]string, len(values))
for i, v := range values {
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
}
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
case "empty", "isnull", "null":
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
case "notempty", "isnotnull", "notnull":
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
}
return ""
}
// ApplyDistinct adds DISTINCT to SQL query if requested
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
if !params.Distinct {
return sqlQuery
}
// Add DISTINCT after SELECT
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
if selectPos >= 0 {
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
afterSelect := sqlQuery[selectPos+6:]
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
logger.Debug("Applied DISTINCT to query")
}
return sqlQuery
}
// sqlQryWhereOr adds a WHERE clause with OR logic
func sqlQryWhereOr(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add OR condition
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
} else {
// No WHERE exists, add it
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}

View File

@ -0,0 +1,549 @@
package funcspec
import (
"strings"
"testing"
)
// TestParseParameters tests the comprehensive parameter parsing
func TestParseParameters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, params *RequestParameters)
}{
{
name: "Parse field selection",
headers: map[string]string{
"X-Select-Fields": "id,name,email",
"X-Not-Select-Fields": "password,ssn",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SelectFields) != 3 {
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
}
if len(params.NotSelectFields) != 2 {
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
}
},
},
{
name: "Parse distinct flag",
headers: map[string]string{
"X-Distinct": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse field filters",
headers: map[string]string{
"X-FieldFilter-Status": "active",
"X-FieldFilter-Type": "admin",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.FieldFilters) != 2 {
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
}
if params.FieldFilters["status"] != "active" {
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
}
},
},
{
name: "Parse search filters",
headers: map[string]string{
"X-SearchFilter-Name": "john",
"X-SearchFilter-Email": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchFilters) != 2 {
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
}
},
},
{
name: "Parse sort columns",
queryParams: map[string]string{
"sort": "-created_at,name",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.SortColumns != "-created_at,name" {
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
}
},
},
{
name: "Parse limit and offset",
queryParams: map[string]string{
"limit": "100",
"offset": "50",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.Limit != 100 {
t.Errorf("Expected limit=100, got %d", params.Limit)
}
if params.Offset != 50 {
t.Errorf("Expected offset=50, got %d", params.Offset)
}
},
},
{
name: "Parse skip count",
headers: map[string]string{
"X-SkipCount": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format - syncfusion",
headers: map[string]string{
"X-Syncfusion": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
}
if !params.ComplexAPI {
t.Error("Expected ComplexAPI to be true for syncfusion format")
}
},
},
{
name: "Parse response format - detail",
headers: map[string]string{
"X-DetailAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "detail" {
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
}
},
},
{
name: "Parse simple API",
headers: map[string]string{
"X-SimpleAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "simple" {
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
}
if params.ComplexAPI {
t.Error("Expected ComplexAPI to be false for simple API")
}
},
},
{
name: "Parse custom SQL WHERE",
headers: map[string]string{
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
},
},
{
name: "Parse search operators - AND",
headers: map[string]string{
"X-SearchOp-Eq-Name": "john",
"X-SearchOp-Gt-Age": "18",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchOps) != 2 {
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
}
if op, exists := params.SearchOps["name"]; exists {
if op.Operator != "eq" {
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
}
if op.Logic != "AND" {
t.Errorf("Expected logic=AND, got %s", op.Logic)
}
} else {
t.Error("Expected name search operator to exist")
}
},
},
{
name: "Parse search operators - OR",
headers: map[string]string{
"X-SearchOr-Like-Description": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if op, exists := params.SearchOps["description"]; exists {
if op.Logic != "OR" {
t.Errorf("Expected logic=OR, got %s", op.Logic)
}
} else {
t.Error("Expected description search operator to exist")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
params := handler.ParseParameters(req)
if tt.validate != nil {
tt.validate(t, params)
}
})
}
}
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
colName string
operator FilterOperator
expected string
}{
{
name: "Equals operator - numeric",
colName: "age",
operator: FilterOperator{
Operator: "eq",
Value: "25",
Logic: "AND",
},
expected: "age = 25",
},
{
name: "Equals operator - string",
colName: "name",
operator: FilterOperator{
Operator: "eq",
Value: "john",
Logic: "AND",
},
expected: "name = 'john'",
},
{
name: "Not equals operator",
colName: "status",
operator: FilterOperator{
Operator: "neq",
Value: "inactive",
Logic: "AND",
},
expected: "status != 'inactive'",
},
{
name: "Greater than operator",
colName: "age",
operator: FilterOperator{
Operator: "gt",
Value: "18",
Logic: "AND",
},
expected: "age > 18",
},
{
name: "Less than operator",
colName: "price",
operator: FilterOperator{
Operator: "lt",
Value: "100",
Logic: "AND",
},
expected: "price < 100",
},
{
name: "Contains operator",
colName: "description",
operator: FilterOperator{
Operator: "contains",
Value: "test",
Logic: "AND",
},
expected: "description ILIKE '%test%'",
},
{
name: "Starts with operator",
colName: "name",
operator: FilterOperator{
Operator: "startswith",
Value: "john",
Logic: "AND",
},
expected: "name ILIKE 'john%'",
},
{
name: "Ends with operator",
colName: "email",
operator: FilterOperator{
Operator: "endswith",
Value: "@example.com",
Logic: "AND",
},
expected: "email ILIKE '%@example.com'",
},
{
name: "Between operator",
colName: "age",
operator: FilterOperator{
Operator: "between",
Value: "18,65",
Logic: "AND",
},
expected: "age > 18 AND age < 65",
},
{
name: "IN operator",
colName: "status",
operator: FilterOperator{
Operator: "in",
Value: "active,pending,approved",
Logic: "AND",
},
expected: "status IN ('active', 'pending', 'approved')",
},
{
name: "IS NULL operator",
colName: "deleted_at",
operator: FilterOperator{
Operator: "null",
Value: "",
Logic: "AND",
},
expected: "(deleted_at IS NULL OR deleted_at = '')",
},
{
name: "IS NOT NULL operator",
colName: "created_at",
operator: FilterOperator{
Operator: "notnull",
Value: "",
Logic: "AND",
},
expected: "(created_at IS NOT NULL AND created_at != '')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildFilterCondition(tt.colName, tt.operator)
if result != tt.expected {
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
}
})
}
}
// TestApplyFilters tests the filter application to SQL queries
func TestApplyFilters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
params *RequestParameters
expectedSQL string
shouldContain []string
}{
{
name: "Apply field filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
FieldFilters: map[string]string{
"status": "active",
},
},
shouldContain: []string{"WHERE", "status"},
},
{
name: "Apply search filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchFilters: map[string]string{
"name": "john",
},
},
shouldContain: []string{"WHERE", "name", "ILIKE"},
},
{
name: "Apply search operators",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchOps: map[string]FilterOperator{
"age": {
Operator: "gt",
Value: "18",
Logic: "AND",
},
},
},
shouldContain: []string{"WHERE", "age", ">", "18"},
},
{
name: "Apply custom SQL WHERE",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
CustomSQLWhere: "deleted = false",
},
shouldContain: []string{"WHERE", "deleted"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}
// TestApplyDistinct tests DISTINCT application
func TestApplyDistinct(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
distinct bool
shouldHave string
}{
{
name: "Apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: true,
shouldHave: "SELECT DISTINCT",
},
{
name: "Do not apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: false,
shouldHave: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := &RequestParameters{Distinct: tt.distinct}
result := handler.ApplyDistinct(tt.sqlQuery, params)
if tt.shouldHave != "" {
if !strings.Contains(result, tt.shouldHave) {
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
}
} else {
// Should not have DISTINCT when not requested
if strings.Contains(result, "DISTINCT") && !tt.distinct {
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
}
}
})
}
}
// TestParseCommaSeparated tests comma-separated value parsing
func TestParseCommaSeparated(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "id,name,email",
expected: []string{"id", "name", "email"},
},
{
name: "With spaces",
input: "id, name, email",
expected: []string{"id", "name", "email"},
},
{
name: "Empty string",
input: "",
expected: nil,
},
{
name: "Single value",
input: "id",
expected: []string{"id"},
},
{
name: "With extra commas",
input: "id,,name,,email",
expected: []string{"id", "name", "email"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
}
}
})
}
}
// TestSqlQryWhereOr tests OR WHERE clause manipulation
func TestSqlQryWhereOr(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
shouldContain []string
}{
{
name: "Add WHERE with OR to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "status = 'inactive'"},
},
{
name: "Add OR to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}

View File

@ -0,0 +1,83 @@
package funcspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 2: BeforeQuery - Audit logging before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Note: Row-level security and column masking are challenging in funcspec
// because the SQL query is fully user-defined. Security should be implemented
// at the SQL function level or through database policies (RLS).
}
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
type funcSpecSecurityContext struct {
ctx *HookContext
}
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
return &funcSpecSecurityContext{ctx: ctx}
}
func (f *funcSpecSecurityContext) GetContext() context.Context {
return f.ctx.Context
}
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
if f.ctx.UserContext == nil {
return 0, false
}
return int(f.ctx.UserContext.UserID), true
}
func (f *funcSpecSecurityContext) GetSchema() string {
// funcspec doesn't have a schema concept, extract from SQL query or use default
return "public"
}
func (f *funcSpecSecurityContext) GetEntity() string {
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
return "sql_query"
}
func (f *funcSpecSecurityContext) GetModel() interface{} {
// funcspec doesn't use models in the same way as restheadspec
return nil
}
func (f *funcSpecSecurityContext) GetQuery() interface{} {
// In funcspec, the query is a string, not a query builder object
return f.ctx.SQLQuery
}
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
// In funcspec, we could modify the SQL string, but this should be done cautiously
if sqlQuery, ok := query.(string); ok {
f.ctx.SQLQuery = sqlQuery
}
}
func (f *funcSpecSecurityContext) GetResult() interface{} {
return f.ctx.Result
}
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
f.ctx.Result = result
}

View File

@ -1,15 +1,19 @@
package logger
import (
"context"
"fmt"
"log"
"os"
"runtime/debug"
"go.uber.org/zap"
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
)
var Logger *zap.SugaredLogger
var errorTracker errortracking.Provider
func Init(dev bool) {
@ -23,6 +27,15 @@ func Init(dev bool) {
}
func UpdateLoggerPath(path string, dev bool) {
defaultConfig := zap.NewProductionConfig()
if dev {
defaultConfig = zap.NewDevelopmentConfig()
}
defaultConfig.OutputPaths = []string{path}
UpdateLogger(&defaultConfig)
}
func UpdateLogger(config *zap.Config) {
defaultConfig := zap.NewProductionConfig()
defaultConfig.OutputPaths = []string{"resolvespec.log"}
@ -40,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
Info("ResolveSpec Logger initialized")
}
// InitErrorTracking initializes the error tracking provider
func InitErrorTracking(provider errortracking.Provider) {
errorTracker = provider
if errorTracker != nil {
Info("Error tracking initialized")
}
}
// GetErrorTracker returns the current error tracking provider
func GetErrorTracker() errortracking.Provider {
return errorTracker
}
// CloseErrorTracking flushes and closes the error tracking provider
func CloseErrorTracking() error {
if errorTracker != nil {
errorTracker.Flush(5)
return errorTracker.Close()
}
return nil
}
func Info(template string, args ...interface{}) {
if Logger == nil {
log.Printf(template, args...)
@ -49,19 +84,35 @@ func Info(template string, args ...interface{}) {
}
func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil {
log.Printf(template, args...)
return
log.Printf("%s", message)
} else {
Logger.Warnw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
}
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
}
func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil {
log.Printf(template, args...)
return
log.Printf("%s", message)
} else {
Logger.Errorw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
}
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
}
func Debug(template string, args ...interface{}) {
@ -75,7 +126,7 @@ func Debug(template string, args ...interface{}) {
// CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil {
// callstack := debug.Stack()
callstack := debug.Stack()
if Logger != nil {
Error("Panic in %s : %v", location, err)
@ -84,14 +135,13 @@ func CatchPanicCallback(location string, cb func(err any)) {
debug.PrintStack()
}
// push to sentry
// hub := sentry.CurrentHub()
// if hub != nil {
// evtID := hub.Recover(err)
// if evtID != nil {
// sentry.Flush(time.Second * 2)
// }
// }
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}
if cb != nil {
cb(err)
@ -103,3 +153,27 @@ func CatchPanicCallback(location string, cb func(err any)) {
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
}
// HandlePanic logs a panic and returns it as an error
// This should be called with the result of recover() from a deferred function
// Example usage:
//
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
}
return fmt.Errorf("panic in %s: %v", methodName, r)
}

259
pkg/metrics/README.md Normal file
View File

@ -0,0 +1,259 @@
# Metrics Package
A pluggable metrics collection system with Prometheus implementation.
## Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
// Initialize Prometheus provider
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Apply middleware to your router
router.Use(provider.Middleware)
// Expose metrics endpoint
http.Handle("/metrics", provider.Handler())
```
## Provider Interface
The package uses a provider interface, allowing you to plug in different metric systems:
```go
type Provider interface {
RecordHTTPRequest(method, path, status string, duration time.Duration)
IncRequestsInFlight()
DecRequestsInFlight()
RecordDBQuery(operation, table string, duration time.Duration, err error)
RecordCacheHit(provider string)
RecordCacheMiss(provider string)
UpdateCacheSize(provider string, size int64)
Handler() http.Handler
}
```
## Recording Metrics
### HTTP Metrics (Automatic)
When using the middleware, HTTP metrics are recorded automatically:
```go
router.Use(provider.Middleware)
```
**Collected:**
- Request duration (histogram)
- Request count by method, path, and status
- Requests in flight (gauge)
### Database Metrics
```go
start := time.Now()
rows, err := db.Query("SELECT * FROM users WHERE id = ?", userID)
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
```
### Cache Metrics
```go
// Record cache hit
metrics.GetProvider().RecordCacheHit("memory")
// Record cache miss
metrics.GetProvider().RecordCacheMiss("memory")
// Update cache size
metrics.GetProvider().UpdateCacheSize("memory", 1024)
```
## Prometheus Metrics
When using `PrometheusProvider`, the following metrics are available:
| Metric Name | Type | Labels | Description |
|-------------|------|--------|-------------|
| `http_request_duration_seconds` | Histogram | method, path, status | HTTP request duration |
| `http_requests_total` | Counter | method, path, status | Total HTTP requests |
| `http_requests_in_flight` | Gauge | - | Current in-flight requests |
| `db_query_duration_seconds` | Histogram | operation, table | Database query duration |
| `db_queries_total` | Counter | operation, table, status | Total database queries |
| `cache_hits_total` | Counter | provider | Total cache hits |
| `cache_misses_total` | Counter | provider | Total cache misses |
| `cache_size_items` | Gauge | provider | Current cache size |
## Prometheus Queries
### HTTP Request Rate
```promql
rate(http_requests_total[5m])
```
### HTTP Request Duration (95th percentile)
```promql
histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))
```
### Database Query Error Rate
```promql
rate(db_queries_total{status="error"}[5m])
```
### Cache Hit Rate
```promql
rate(cache_hits_total[5m]) / (rate(cache_hits_total[5m]) + rate(cache_misses_total[5m]))
```
## No-Op Provider
If metrics are disabled:
```go
// No provider set - uses no-op provider automatically
metrics.GetProvider().RecordHTTPRequest(...) // Does nothing
```
## Custom Provider
Implement your own metrics provider:
```go
type CustomProvider struct{}
func (c *CustomProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
// Send to your metrics system
}
// Implement other Provider interface methods...
func (c *CustomProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return your metrics format
})
}
// Use it
metrics.SetProvider(&CustomProvider{})
```
## Complete Example
```go
package main
import (
"database/sql"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/gorilla/mux"
)
func main() {
// Initialize metrics
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Create router
router := mux.NewRouter()
// Apply metrics middleware
router.Use(provider.Middleware)
// Expose metrics endpoint
router.Handle("/metrics", provider.Handler())
// Your API routes
router.HandleFunc("/api/users", getUsersHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
func getUsersHandler(w http.ResponseWriter, r *http.Request) {
// Record database query
start := time.Now()
users, err := fetchUsers()
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
if err != nil {
http.Error(w, "Internal Server Error", 500)
return
}
// Return users...
}
```
## Docker Compose Example
```yaml
version: '3'
services:
app:
build: .
ports:
- "8080:8080"
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
grafana:
image: grafana/grafana
ports:
- "3000:3000"
depends_on:
- prometheus
```
**prometheus.yml:**
```yaml
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'resolvespec'
static_configs:
- targets: ['app:8080']
```
## Best Practices
1. **Label Cardinality**: Keep labels low-cardinality
- ✅ Good: `method`, `status_code`
- ❌ Bad: `user_id`, `timestamp`
2. **Path Normalization**: Normalize dynamic paths
```go
// Instead of /api/users/123
// Use /api/users/:id
```
3. **Metric Naming**: Follow Prometheus conventions
- Use `_total` suffix for counters
- Use `_seconds` suffix for durations
- Use base units (seconds, not milliseconds)
4. **Performance**: Metrics collection is lock-free and highly performant
- Safe for high-throughput applications
- Minimal overhead (<1% in most cases)

86
pkg/metrics/interfaces.go Normal file
View File

@ -0,0 +1,86 @@
package metrics
import (
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Provider defines the interface for metric collection
type Provider interface {
// RecordHTTPRequest records metrics for an HTTP request
RecordHTTPRequest(method, path, status string, duration time.Duration)
// IncRequestsInFlight increments the in-flight requests counter
IncRequestsInFlight()
// DecRequestsInFlight decrements the in-flight requests counter
DecRequestsInFlight()
// RecordDBQuery records metrics for a database query
RecordDBQuery(operation, table string, duration time.Duration, err error)
// RecordCacheHit records a cache hit
RecordCacheHit(provider string)
// RecordCacheMiss records a cache miss
RecordCacheMiss(provider string)
// UpdateCacheSize updates the cache size metric
UpdateCacheSize(provider string, size int64)
// RecordEventPublished records an event publication
RecordEventPublished(source, eventType string)
// RecordEventProcessed records an event processing with its status
RecordEventProcessed(source, eventType, status string, duration time.Duration)
// UpdateEventQueueSize updates the event queue size metric
UpdateEventQueueSize(size int64)
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
Handler() http.Handler
}
// globalProvider is the global metrics provider
var globalProvider Provider
// SetProvider sets the global metrics provider
func SetProvider(p Provider) {
globalProvider = p
}
// GetProvider returns the current metrics provider
func GetProvider() Provider {
if globalProvider == nil {
// Return no-op provider if none is set
return &NoOpProvider{}
}
return globalProvider
}
// NoOpProvider is a no-op implementation of Provider
type NoOpProvider struct{}
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
func (n *NoOpProvider) IncRequestsInFlight() {}
func (n *NoOpProvider) DecRequestsInFlight() {}
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
}
func (n *NoOpProvider) RecordCacheHit(provider string) {}
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
func (n *NoOpProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, err := w.Write([]byte("Metrics provider not configured"))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
})
}

174
pkg/metrics/prometheus.go Normal file
View File

@ -0,0 +1,174 @@
package metrics
import (
"net/http"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// PrometheusProvider implements the Provider interface using Prometheus
type PrometheusProvider struct {
requestDuration *prometheus.HistogramVec
requestTotal *prometheus.CounterVec
requestsInFlight prometheus.Gauge
dbQueryDuration *prometheus.HistogramVec
dbQueryTotal *prometheus.CounterVec
cacheHits *prometheus.CounterVec
cacheMisses *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
}
// NewPrometheusProvider creates a new Prometheus metrics provider
func NewPrometheusProvider() *PrometheusProvider {
return &PrometheusProvider{
requestDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
),
requestTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
),
requestsInFlight: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "http_requests_in_flight",
Help: "Current number of HTTP requests being processed",
},
),
dbQueryDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "db_query_duration_seconds",
Help: "Database query duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"operation", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "db_queries_total",
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"provider"},
),
cacheMisses: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"provider"},
),
cacheSize: promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "cache_size_items",
Help: "Number of items in cache",
},
[]string{"provider"},
),
}
}
// ResponseWriter wraps http.ResponseWriter to capture status code
type ResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
}
func (rw *ResponseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// RecordHTTPRequest implements Provider interface
func (p *PrometheusProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
p.requestDuration.WithLabelValues(method, path, status).Observe(duration.Seconds())
p.requestTotal.WithLabelValues(method, path, status).Inc()
}
// IncRequestsInFlight implements Provider interface
func (p *PrometheusProvider) IncRequestsInFlight() {
p.requestsInFlight.Inc()
}
// DecRequestsInFlight implements Provider interface
func (p *PrometheusProvider) DecRequestsInFlight() {
p.requestsInFlight.Dec()
}
// RecordDBQuery implements Provider interface
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
status := "success"
if err != nil {
status = "error"
}
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
}
// RecordCacheHit implements Provider interface
func (p *PrometheusProvider) RecordCacheHit(provider string) {
p.cacheHits.WithLabelValues(provider).Inc()
}
// RecordCacheMiss implements Provider interface
func (p *PrometheusProvider) RecordCacheMiss(provider string) {
p.cacheMisses.WithLabelValues(provider).Inc()
}
// UpdateCacheSize implements Provider interface
func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
p.cacheSize.WithLabelValues(provider).Set(float64(size))
}
// Handler implements Provider interface
func (p *PrometheusProvider) Handler() http.Handler {
return promhttp.Handler()
}
// Middleware returns an HTTP middleware that collects metrics
func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Increment in-flight requests
p.IncRequestsInFlight()
defer p.DecRequestsInFlight()
// Wrap response writer to capture status code
rw := NewResponseWriter(w)
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
duration := time.Since(start)
status := strconv.Itoa(rw.statusCode)
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
})
}

806
pkg/middleware/README.md Normal file
View File

@ -0,0 +1,806 @@
# Middleware Package
HTTP middleware utilities for security and performance.
## Table of Contents
1. [Rate Limiting](#rate-limiting)
2. [Request Size Limits](#request-size-limits)
3. [Input Sanitization](#input-sanitization)
---
## Rate Limiting
Production-grade rate limiting using token bucket algorithm.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Create rate limiter: 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
// Apply to all routes
router.Use(rateLimiter.Middleware)
```
### Basic Usage
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Rate limit: 10 requests per second, burst of 5
rateLimiter := middleware.NewRateLimiter(10, 5)
router.Use(rateLimiter.Middleware)
router.HandleFunc("/api/data", dataHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
```
### Custom Key Extraction
By default, rate limiting is per IP address. Customize the key:
```go
// Rate limit by User ID from header
keyFunc := func(r *http.Request) string {
userID := r.Header.Get("X-User-ID")
if userID == "" {
return r.RemoteAddr // Fallback to IP
}
return "user:" + userID
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
### Advanced Key Functions
**By API Key:**
```go
keyFunc := func(r *http.Request) string {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
return r.RemoteAddr
}
return "api:" + apiKey
}
```
**By Authenticated User:**
```go
keyFunc := func(r *http.Request) string {
// Extract from JWT or session
user := getUserFromContext(r.Context())
if user != nil {
return "user:" + user.ID
}
return r.RemoteAddr
}
```
**By Path + User:**
```go
keyFunc := func(r *http.Request) string {
user := getUserFromContext(r.Context())
if user != nil {
return fmt.Sprintf("user:%s:path:%s", user.ID, r.URL.Path)
}
return r.URL.Path + ":" + r.RemoteAddr
}
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// Public endpoints: 10 rps
publicLimiter := middleware.NewRateLimiter(10, 5)
// API endpoints: 100 rps
apiLimiter := middleware.NewRateLimiter(100, 20)
// Admin endpoints: 1000 rps
adminLimiter := middleware.NewRateLimiter(1000, 50)
// Apply different limiters to subrouters
publicRouter := router.PathPrefix("/public").Subrouter()
publicRouter.Use(publicLimiter.Middleware)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
adminRouter := router.PathPrefix("/admin").Subrouter()
adminRouter.Use(adminLimiter.Middleware)
}
```
### Rate Limit Response
When rate limited, clients receive:
```http
HTTP/1.1 429 Too Many Requests
Content-Type: text/plain
{"error":"rate_limit_exceeded","message":"Too many requests"}
```
### Configuration Examples
**Tight Rate Limit (Anti-abuse):**
```go
// 1 request per second, burst of 3
rateLimiter := middleware.NewRateLimiter(1, 3)
```
**Moderate Rate Limit (Standard API):**
```go
// 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
```
**Generous Rate Limit (Internal Services):**
```go
// 1000 requests per second, burst of 100
rateLimiter := middleware.NewRateLimiter(1000, 100)
```
**Time-based Limits:**
```go
// 60 requests per minute = 1 request per second
rateLimiter := middleware.NewRateLimiter(1, 10)
// 1000 requests per hour ≈ 0.28 requests per second
rateLimiter := middleware.NewRateLimiter(0.28, 50)
```
### Understanding Burst
The burst parameter allows short bursts above the rate:
```go
// Rate: 10 rps, Burst: 5
// Allows up to 5 requests immediately, then 10/second
rateLimiter := middleware.NewRateLimiter(10, 5)
```
**Bucket fills at rate:** 10 tokens/second
**Bucket capacity:** 5 tokens
**Request consumes:** 1 token
**Example traffic pattern:**
- T=0s: 5 requests → ✅ All allowed (burst)
- T=0.1s: 1 request → ❌ Denied (bucket empty)
- T=0.5s: 1 request → ✅ Allowed (bucket refilled 0.5 tokens)
- T=1s: 1 request → ✅ Allowed (bucket has ~1 token)
### Cleanup Behavior
The rate limiter automatically cleans up inactive limiters every 5 minutes to prevent memory leaks.
### Performance Characteristics
- **Memory**: ~100 bytes per active limiter
- **Throughput**: >1M requests/second
- **Latency**: <1μs per request
- **Concurrency**: Lock-free for rate checks
### Production Deployment
**With Reverse Proxy:**
```go
// Use X-Forwarded-For or X-Real-IP
keyFunc := func(r *http.Request) string {
// Check proxy headers first
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
return r.RemoteAddr
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
**Environment-based Configuration:**
```go
import "os"
func getRateLimiter() *middleware.RateLimiter {
rps := getEnvFloat("RATE_LIMIT_RPS", 100)
burst := getEnvInt("RATE_LIMIT_BURST", 20)
return middleware.NewRateLimiter(rps, burst)
}
```
### Testing Rate Limits
```bash
# Send 10 requests rapidly
for i in {1..10}; do
curl -w "Status: %{http_code}\n" http://localhost:8080/api/data
done
```
**Expected output:**
```
Status: 200 # Request 1-5 (within burst)
Status: 200
Status: 200
Status: 200
Status: 200
Status: 429 # Request 6-10 (rate limited)
Status: 429
Status: 429
Status: 429
Status: 429
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"strconv"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
// Configuration from environment
rps, _ := strconv.ParseFloat(os.Getenv("RATE_LIMIT_RPS"), 64)
if rps == 0 {
rps = 100 // Default
}
burst, _ := strconv.Atoi(os.Getenv("RATE_LIMIT_BURST"))
if burst == 0 {
burst = 20 // Default
}
// Create rate limiter
rateLimiter := middleware.NewRateLimiter(rps, burst)
// Custom key extraction
keyFunc := func(r *http.Request) string {
// Try API key first
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
return "api:" + apiKey
}
// Try authenticated user
if userID := r.Header.Get("X-User-ID"); userID != "" {
return "user:" + userID
}
// Fall back to IP
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}
// Create router
router := mux.NewRouter()
// Apply rate limiting
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
// Routes
router.HandleFunc("/api/data", dataHandler)
router.HandleFunc("/health", healthHandler)
log.Printf("Starting server with rate limit: %.1f rps, burst: %d", rps, burst)
log.Fatal(http.ListenAndServe(":8080", router))
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"message": "Data endpoint",
})
}
func healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
```
## Best Practices
1. **Set Appropriate Limits**: Consider your backend capacity
- Database: Can it handle X queries/second?
- External APIs: What are their rate limits?
- Server resources: CPU, memory, connections
2. **Use Burst Wisely**: Allow legitimate traffic spikes
- Too low: Reject valid bursts
- Too high: Allow abuse
3. **Monitor Rate Limits**: Track how often limits are hit
```go
// Log rate limit events
if rateLimited {
log.Printf("Rate limited: %s", clientKey)
}
```
4. **Provide Feedback**: Include rate limit headers (future enhancement)
```http
X-RateLimit-Limit: 100
X-RateLimit-Remaining: 95
X-RateLimit-Reset: 1640000000
```
5. **Tiered Limits**: Different limits for different user tiers
```go
func getRateLimiter(userTier string) *middleware.RateLimiter {
switch userTier {
case "premium":
return middleware.NewRateLimiter(1000, 100)
case "standard":
return middleware.NewRateLimiter(100, 20)
default:
return middleware.NewRateLimiter(10, 5)
}
}
```
---
## Request Size Limits
Protect against oversized request bodies with configurable size limits.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default: 10MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(0)
router.Use(sizeLimiter.Middleware)
```
### Custom Size Limit
```go
// 5MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(5 * 1024 * 1024)
router.Use(sizeLimiter.Middleware)
// Or use constants
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size5MB)
```
### Available Size Constants
```go
middleware.Size1MB // 1 MB
middleware.Size5MB // 5 MB
middleware.Size10MB // 10 MB (default)
middleware.Size50MB // 50 MB
middleware.Size100MB // 100 MB
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// File upload endpoint: 50MB
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
uploadRouter := router.PathPrefix("/upload").Subrouter()
uploadRouter.Use(uploadLimiter.Middleware)
// API endpoints: 1MB
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
}
```
### Dynamic Size Limits
```go
// Custom size based on request
sizeFunc := func(r *http.Request) int64 {
// Premium users get 50MB
if isPremiumUser(r) {
return middleware.Size50MB
}
// Free users get 5MB
return middleware.Size5MB
}
router.Use(sizeLimiter.MiddlewareWithCustomSize(sizeFunc))
```
**By Content-Type:**
```go
sizeFunc := func(r *http.Request) int64 {
contentType := r.Header.Get("Content-Type")
switch {
case strings.Contains(contentType, "multipart/form-data"):
return middleware.Size50MB // File uploads
case strings.Contains(contentType, "application/json"):
return middleware.Size1MB // JSON APIs
default:
return middleware.Size10MB // Default
}
}
```
### Error Response
When size limit exceeded:
```http
HTTP/1.1 413 Request Entity Too Large
X-Max-Request-Size: 10485760
http: request body too large
```
### Complete Example
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// API routes: 1MB limit
api := router.PathPrefix("/api").Subrouter()
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
api.Use(apiLimiter.Middleware)
api.HandleFunc("/users", createUserHandler).Methods("POST")
// Upload routes: 50MB limit
upload := router.PathPrefix("/upload").Subrouter()
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
upload.Use(uploadLimiter.Middleware)
upload.HandleFunc("/file", uploadFileHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
```
---
## Input Sanitization
Protect against XSS, injection attacks, and malicious input.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default sanitizer (safe defaults)
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
```
### Sanitizer Types
**Default Sanitizer (Recommended):**
```go
sanitizer := middleware.DefaultSanitizer()
// ✓ Escapes HTML entities
// ✓ Removes null bytes
// ✓ Removes control characters
// ✓ Blocks XSS patterns (script tags, event handlers)
// ✗ Does not strip HTML (allows legitimate content)
```
**Strict Sanitizer:**
```go
sanitizer := middleware.StrictSanitizer()
// ✓ All default features
// ✓ Strips ALL HTML tags
// ✓ Max string length: 10,000 chars
```
### Custom Configuration
```go
sanitizer := &middleware.Sanitizer{
StripHTML: true, // Remove HTML tags
EscapeHTML: false, // Don't escape (already stripped)
RemoveNullBytes: true, // Remove \x00
RemoveControlChars: true, // Remove dangerous control chars
MaxStringLength: 5000, // Limit to 5000 chars
// Block patterns (regex)
BlockPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
},
// Custom sanitization function
CustomSanitizer: func(s string) string {
// Your custom logic
return strings.ToLower(s)
},
}
router.Use(sanitizer.Middleware)
```
### What Gets Sanitized
**Automatic (via middleware):**
- Query parameters
- Headers (User-Agent, Referer, X-Forwarded-For, X-Real-IP)
**Manual (in your handler):**
- Request body (JSON, form data)
- Database queries
- File names
### Manual Sanitization
**String Values:**
```go
sanitizer := middleware.DefaultSanitizer()
// Sanitize user input
username := sanitizer.Sanitize(r.FormValue("username"))
email := sanitizer.Sanitize(r.FormValue("email"))
```
**Map/JSON Data:**
```go
var data map[string]interface{}
json.Unmarshal(body, &data)
// Sanitize all string values recursively
sanitizedData := sanitizer.SanitizeMap(data)
```
**Nested Structures:**
```go
type User struct {
Name string
Email string
Bio string
Profile map[string]interface{}
}
// After unmarshaling
user.Name = sanitizer.Sanitize(user.Name)
user.Email = sanitizer.Sanitize(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
user.Profile = sanitizer.SanitizeMap(user.Profile)
```
### Specialized Sanitizers
**Filenames:**
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
filename := middleware.SanitizeFilename(uploadedFilename)
// Removes: .., /, \, null bytes
// Limits: 255 characters
```
**Emails:**
```go
email := middleware.SanitizeEmail(" USER@EXAMPLE.COM ")
// Result: "user@example.com"
// Trims, lowercases, removes null bytes
```
**URLs:**
```go
url := middleware.SanitizeURL(userInput)
// Blocks: javascript:, data: protocols
// Removes: null bytes
```
### Blocked Patterns (Default)
The default sanitizer blocks:
1. **Script tags**: `<script>...</script>`
2. **JavaScript protocol**: `javascript:alert(1)`
3. **Event handlers**: `onclick="..."`, `onerror="..."`
4. **Iframes**: `<iframe src="...">`
5. **Objects**: `<object data="...">`
6. **Embeds**: `<embed src="...">`
### Security Best Practices
**1. Layer Defense:**
```go
// Layer 1: Middleware (query params, headers)
router.Use(sanitizer.Middleware)
// Layer 2: Input validation (in handler)
func createUserHandler(w http.ResponseWriter, r *http.Request) {
var user User
json.NewDecoder(r.Body).Decode(&user)
// Sanitize
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
// Validate
if !isValidEmail(user.Email) {
http.Error(w, "Invalid email", 400)
return
}
// Use parameterized queries (prevents SQL injection)
db.Exec("INSERT INTO users (name, email) VALUES (?, ?)",
user.Name, user.Email)
}
```
**2. Context-Aware Sanitization:**
```go
// HTML content (user posts, comments)
sanitizer := middleware.StrictSanitizer()
post.Content = sanitizer.Sanitize(post.Content)
// Structured data (JSON API)
sanitizer := middleware.DefaultSanitizer()
data = sanitizer.SanitizeMap(jsonData)
// Search queries (preserve special chars)
query = middleware.SanitizeFilename(searchTerm) // Light sanitization
```
**3. Output Encoding:**
```go
// When rendering HTML
import "html/template"
tmpl := template.Must(template.New("page").Parse(`
<h1>{{.Title}}</h1>
<p>{{.Content}}</p>
`))
// template.HTML automatically escapes
tmpl.Execute(w, data)
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Apply sanitization middleware
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
router.HandleFunc("/api/users", createUserHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
func createUserHandler(w http.ResponseWriter, r *http.Request) {
sanitizer := middleware.DefaultSanitizer()
var user struct {
Name string `json:"name"`
Email string `json:"email"`
Bio string `json:"bio"`
}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
http.Error(w, "Invalid JSON", 400)
return
}
// Sanitize inputs
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
// Validate
if len(user.Name) == 0 || len(user.Email) == 0 {
http.Error(w, "Name and email required", 400)
return
}
// Save to database (use parameterized queries!)
// db.Exec("INSERT INTO users (name, email, bio) VALUES (?, ?, ?)",
// user.Name, user.Email, user.Bio)
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{
"status": "created",
})
}
```
### Testing Sanitization
```bash
# Test XSS prevention
curl -X POST http://localhost:8080/api/users \
-H "Content-Type: application/json" \
-d '{
"name": "<script>alert(1)</script>John",
"email": "test@example.com",
"bio": "My bio with <iframe src=\"evil.com\"></iframe>"
}'
# Script tags and iframes should be removed
```
### Performance
- **Overhead**: <1ms per request for typical payloads
- **Regex compilation**: Done once at initialization
- **Safe for production**: Minimal performance impact

212
pkg/middleware/blacklist.go Normal file
View File

@ -0,0 +1,212 @@
package middleware
import (
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// IPBlacklist provides IP blocking functionality
type IPBlacklist struct {
mu sync.RWMutex
ips map[string]bool // Individual IPs
cidrs []*net.IPNet // CIDR ranges
reason map[string]string
useProxy bool // Whether to check X-Forwarded-For headers
}
// BlacklistConfig configures the IP blacklist
type BlacklistConfig struct {
// UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers
UseProxy bool
}
// NewIPBlacklist creates a new IP blacklist
func NewIPBlacklist(config BlacklistConfig) *IPBlacklist {
return &IPBlacklist{
ips: make(map[string]bool),
cidrs: make([]*net.IPNet, 0),
reason: make(map[string]string),
useProxy: config.UseProxy,
}
}
// BlockIP blocks a single IP address
func (bl *IPBlacklist) BlockIP(ip string, reason string) error {
// Validate IP
if net.ParseIP(ip) == nil {
return &net.ParseError{Type: "IP address", Text: ip}
}
bl.mu.Lock()
defer bl.mu.Unlock()
bl.ips[ip] = true
if reason != "" {
bl.reason[ip] = reason
}
return nil
}
// BlockCIDR blocks an IP range using CIDR notation
func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
bl.mu.Lock()
defer bl.mu.Unlock()
bl.cidrs = append(bl.cidrs, ipNet)
if reason != "" {
bl.reason[cidr] = reason
}
return nil
}
// UnblockIP removes an IP from the blacklist
func (bl *IPBlacklist) UnblockIP(ip string) {
bl.mu.Lock()
defer bl.mu.Unlock()
delete(bl.ips, ip)
delete(bl.reason, ip)
}
// UnblockCIDR removes a CIDR range from the blacklist
func (bl *IPBlacklist) UnblockCIDR(cidr string) {
bl.mu.Lock()
defer bl.mu.Unlock()
// Find and remove the CIDR
for i, ipNet := range bl.cidrs {
if ipNet.String() == cidr {
bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...)
break
}
}
delete(bl.reason, cidr)
}
// IsBlocked checks if an IP is blacklisted
func (bl *IPBlacklist) IsBlocked(ip string) (blacklist bool, reason string) {
bl.mu.RLock()
defer bl.mu.RUnlock()
// Check individual IPs
if bl.ips[ip] {
return true, bl.reason[ip]
}
// Check CIDR ranges
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false, ""
}
for i, ipNet := range bl.cidrs {
if ipNet.Contains(parsedIP) {
cidr := ipNet.String()
// Try to find reason by CIDR or by index
if reason, ok := bl.reason[cidr]; ok {
return true, reason
}
// Check if reason was stored by original CIDR string
for key, reason := range bl.reason {
if strings.Contains(key, "/") && key == cidr {
return true, reason
}
}
// Return true even if no reason found
if i < len(bl.cidrs) {
return true, ""
}
}
}
return false, ""
}
// GetBlacklist returns all blacklisted IPs and CIDRs
func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) {
bl.mu.RLock()
defer bl.mu.RUnlock()
ips = make([]string, 0, len(bl.ips))
for ip := range bl.ips {
ips = append(ips, ip)
}
cidrs = make([]string, 0, len(bl.cidrs))
for _, ipNet := range bl.cidrs {
cidrs = append(cidrs, ipNet.String())
}
return ips, cidrs
}
// Middleware returns an HTTP middleware that blocks blacklisted IPs
func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var clientIP string
if bl.useProxy {
clientIP = getClientIP(r)
// Clean up IPv6 brackets if present
clientIP = strings.Trim(clientIP, "[]")
} else {
// Extract IP from RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
clientIP = r.RemoteAddr[:idx]
} else {
clientIP = r.RemoteAddr
}
clientIP = strings.Trim(clientIP, "[]")
}
blocked, reason := bl.IsBlocked(clientIP)
if blocked {
response := map[string]interface{}{
"error": "forbidden",
"message": "Access denied",
}
if reason != "" {
response["reason"] = reason
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
err := json.NewEncoder(w).Encode(response)
if err != nil {
logger.Debug("Failed to write blacklist response: %v", err)
}
return
}
next.ServeHTTP(w, r)
})
}
// StatsHandler returns an HTTP handler that shows blacklist statistics
func (bl *IPBlacklist) StatsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ips, cidrs := bl.GetBlacklist()
stats := map[string]interface{}{
"blocked_ips": ips,
"blocked_cidrs": cidrs,
"total_ips": len(ips),
"total_cidrs": len(cidrs),
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(stats)
if err != nil {
logger.Debug("Failed to encode stats: %v", err)
}
})
}

View File

@ -0,0 +1,254 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestIPBlacklist_BlockIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block an IP
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
if err != nil {
t.Fatalf("BlockIP() error = %v", err)
}
// Check if IP is blocked
blocked, reason := bl.IsBlocked("192.168.1.100")
if !blocked {
t.Error("IP should be blocked")
}
if reason != "Suspicious activity" {
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
}
// Check non-blocked IP
blocked, _ = bl.IsBlocked("192.168.1.1")
if blocked {
t.Error("IP should not be blocked")
}
}
func TestIPBlacklist_BlockCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block a CIDR range
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
if err != nil {
t.Fatalf("BlockCIDR() error = %v", err)
}
// Check IPs in range
testIPs := []string{
"10.0.0.1",
"10.0.0.100",
"10.0.0.254",
}
for _, ip := range testIPs {
blocked, _ := bl.IsBlocked(ip)
if !blocked {
t.Errorf("IP %s should be blocked by CIDR", ip)
}
}
// Check IP outside range
blocked, _ := bl.IsBlocked("10.0.1.1")
if blocked {
t.Error("IP outside CIDR range should not be blocked")
}
}
func TestIPBlacklist_UnblockIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block and then unblock
bl.BlockIP("192.168.1.100", "Test")
blocked, _ := bl.IsBlocked("192.168.1.100")
if !blocked {
t.Error("IP should be blocked")
}
bl.UnblockIP("192.168.1.100")
blocked, _ = bl.IsBlocked("192.168.1.100")
if blocked {
t.Error("IP should be unblocked")
}
}
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block and then unblock CIDR
bl.BlockCIDR("10.0.0.0/24", "Test")
blocked, _ := bl.IsBlocked("10.0.0.1")
if !blocked {
t.Error("IP should be blocked by CIDR")
}
bl.UnblockCIDR("10.0.0.0/24")
blocked, _ = bl.IsBlocked("10.0.0.1")
if blocked {
t.Error("IP should be unblocked after CIDR removal")
}
}
func TestIPBlacklist_Middleware(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "Banned")
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
// Blocked IP should get 403
t.Run("BlockedIP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "forbidden" {
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
}
})
// Allowed IP should succeed
t.Run("AllowedIP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
bl.BlockIP("203.0.113.1", "Blocked via proxy")
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test X-Forwarded-For
t.Run("X-Forwarded-For", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "10.0.0.1:12345"
req.Header.Set("X-Forwarded-For", "203.0.113.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
})
// Test X-Real-IP
t.Run("X-Real-IP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "10.0.0.1:12345"
req.Header.Set("X-Real-IP", "203.0.113.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
})
}
func TestIPBlacklist_StatsHandler(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "Test1")
bl.BlockIP("192.168.1.101", "Test2")
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
handler := bl.StatsHandler()
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if int(stats["total_ips"].(float64)) != 2 {
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
}
if int(stats["total_cidrs"].(float64)) != 1 {
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
}
}
func TestIPBlacklist_GetBlacklist(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "")
bl.BlockIP("192.168.1.101", "")
bl.BlockCIDR("10.0.0.0/24", "")
ips, cidrs := bl.GetBlacklist()
if len(ips) != 2 {
t.Errorf("len(ips) = %d, want 2", len(ips))
}
if len(cidrs) != 1 {
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
}
// Verify CIDR format
if cidrs[0] != "10.0.0.0/24" {
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
}
}
func TestIPBlacklist_InvalidIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
err := bl.BlockIP("invalid-ip", "Test")
if err == nil {
t.Error("BlockIP() should return error for invalid IP")
}
}
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
err := bl.BlockCIDR("invalid-cidr", "Test")
if err == nil {
t.Error("BlockCIDR() should return error for invalid CIDR")
}
}

233
pkg/middleware/ratelimit.go Normal file
View File

@ -0,0 +1,233 @@
// Package middleware provides HTTP middleware functionalities such as rate limiting and IP blacklisting.
package middleware
//nolint:all
import (
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"golang.org/x/time/rate"
)
// RateLimiter provides rate limiting functionality
type RateLimiter struct {
mu sync.RWMutex
limiters map[string]*rate.Limiter
rate rate.Limit
burst int
cleanup time.Duration
}
// NewRateLimiter creates a new rate limiter
// rps is requests per second, burst is the maximum burst size
func NewRateLimiter(rps float64, burst int) *RateLimiter {
rl := &RateLimiter{
limiters: make(map[string]*rate.Limiter),
rate: rate.Limit(rps),
burst: burst,
cleanup: 5 * time.Minute, // Clean up stale limiters every 5 minutes
}
// Start cleanup goroutine
go rl.cleanupRoutine()
return rl
}
// getLimiter returns the rate limiter for a given key (e.g., IP address)
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
limiter, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
return limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
// Double-check after acquiring write lock
if limiter, exists := rl.limiters[key]; exists {
return limiter
}
limiter = rate.NewLimiter(rl.rate, rl.burst)
rl.limiters[key] = limiter
return limiter
}
// cleanupRoutine periodically removes inactive limiters
func (rl *RateLimiter) cleanupRoutine() {
ticker := time.NewTicker(rl.cleanup)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
// Simple cleanup: remove all limiters
// In production, you might want to track last access time
rl.limiters = make(map[string]*rate.Limiter)
rl.mu.Unlock()
}
}
// Middleware returns an HTTP middleware that applies rate limiting
// Automatically handles X-Forwarded-For headers when behind a proxy
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract client IP, handling proxy headers
key := getClientIP(r)
limiter := rl.getLimiter(key)
if !limiter.Allow() {
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
// MiddlewareWithKeyFunc returns an HTTP middleware with a custom key extraction function
func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := keyFunc(r)
if key == "" {
key = r.RemoteAddr
}
limiter := rl.getLimiter(key)
if !limiter.Allow() {
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// RateLimitInfo contains information about a specific IP's rate limit status
type RateLimitInfo struct {
IP string `json:"ip"`
TokensRemaining float64 `json:"tokens_remaining"`
Limit float64 `json:"limit"`
Burst int `json:"burst"`
}
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
func (rl *RateLimiter) GetTrackedIPs() []string {
rl.mu.RLock()
defer rl.mu.RUnlock()
ips := make([]string, 0, len(rl.limiters))
for ip := range rl.limiters {
ips = append(ips, ip)
}
return ips
}
// GetRateLimitInfo returns rate limit information for a specific IP
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
rl.mu.RLock()
limiter, exists := rl.limiters[ip]
rl.mu.RUnlock()
if !exists {
// Return default info for untracked IP
return &RateLimitInfo{
IP: ip,
TokensRemaining: float64(rl.burst),
Limit: float64(rl.rate),
Burst: rl.burst,
}
}
return &RateLimitInfo{
IP: ip,
TokensRemaining: limiter.Tokens(),
Limit: float64(rl.rate),
Burst: rl.burst,
}
}
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
ips := rl.GetTrackedIPs()
info := make([]*RateLimitInfo, 0, len(ips))
for _, ip := range ips {
info = append(info, rl.GetRateLimitInfo(ip))
}
return info
}
// StatsHandler returns an HTTP handler that exposes rate limit statistics
// Example: GET /rate-limit-stats
func (rl *RateLimiter) StatsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Support querying specific IP via ?ip=x.x.x.x
if ip := r.URL.Query().Get("ip"); ip != "" {
info := rl.GetRateLimitInfo(ip)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(info)
if err != nil {
logger.Debug("Failed to encode json: %v", err)
}
return
}
// Return all tracked IPs
allInfo := rl.GetAllRateLimitInfo()
stats := map[string]interface{}{
"total_tracked_ips": len(allInfo),
"rate_limit_config": map[string]interface{}{
"requests_per_second": float64(rl.rate),
"burst": rl.burst,
},
"tracked_ips": allInfo,
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(stats)
if err != nil {
logger.Debug("Failed to encode json: %v", err)
}
})
}
// getClientIP extracts the real client IP from the request
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (most common in production)
// Format: X-Forwarded-For: client, proxy1, proxy2
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (the original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header (used by some proxies like nginx)
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
// Remove port if present (format: "ip:port")
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}

View File

@ -0,0 +1,388 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRateLimiter(t *testing.T) {
// Create rate limiter: 2 requests per second, burst of 2
rl := NewRateLimiter(2, 2)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
// First request should succeed
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Second request should succeed (within burst)
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Third request should be rate limited
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests)
}
// Wait for rate limiter to refill
time.Sleep(600 * time.Millisecond)
// Request should succeed again
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK)
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
rl := NewRateLimiter(1, 1)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First IP
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
// Second IP
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.2:12345"
// Both should succeed (different IPs)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w1.Code != http.StatusOK {
t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK)
}
if w2.Code != http.StatusOK {
t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK)
}
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
}{
{
name: "RemoteAddr only",
remoteAddr: "192.168.1.1:12345",
expectedIP: "192.168.1.1",
},
{
name: "X-Forwarded-For single IP",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1",
expectedIP: "203.0.113.1",
},
{
name: "X-Forwarded-For multiple IPs",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3",
expectedIP: "203.0.113.1",
},
{
name: "X-Real-IP",
remoteAddr: "10.0.0.1:12345",
xRealIP: "203.0.113.1",
expectedIP: "203.0.113.1",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1",
xRealIP: "203.0.113.2",
expectedIP: "203.0.113.1",
},
{
name: "IPv6 address",
remoteAddr: "[2001:db8::1]:12345",
expectedIP: "[2001:db8::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
if ip != tt.expectedIP {
t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP)
}
})
}
}
func TestRateLimiterWithCustomKeyFunc(t *testing.T) {
rl := NewRateLimiter(1, 1)
// Use user ID as key
keyFunc := func(r *http.Request) string {
userID := r.Header.Get("X-User-ID")
if userID == "" {
return r.RemoteAddr
}
return "user:" + userID
}
handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// User 1
req1 := httptest.NewRequest("GET", "/test", nil)
req1.Header.Set("X-User-ID", "user1")
// User 2
req2 := httptest.NewRequest("GET", "/test", nil)
req2.Header.Set("X-User-ID", "user2")
// Both users should succeed (different keys)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w1.Code != http.StatusOK {
t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK)
}
if w2.Code != http.StatusOK {
t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK)
}
// User 1 second request should be rate limited
w1 = httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
if w1.Code != http.StatusTooManyRequests {
t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests)
}
}
func TestRateLimiter_GetTrackedIPs(t *testing.T) {
rl := NewRateLimiter(10, 10)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = ip + ":12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
// Check tracked IPs
trackedIPs := rl.GetTrackedIPs()
if len(trackedIPs) != len(ips) {
t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips))
}
// Verify all IPs are tracked
ipMap := make(map[string]bool)
for _, ip := range trackedIPs {
ipMap[ip] = true
}
for _, ip := range ips {
if !ipMap[ip] {
t.Errorf("IP %s should be tracked", ip)
}
}
}
func TestRateLimiter_GetRateLimitInfo(t *testing.T) {
rl := NewRateLimiter(10, 5)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make a request
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Get rate limit info
info := rl.GetRateLimitInfo("192.168.1.1")
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
if info.Limit != 10.0 {
t.Errorf("Limit = %f, want 10.0", info.Limit)
}
if info.Burst != 5 {
t.Errorf("Burst = %d, want 5", info.Burst)
}
// Tokens should be less than burst after one request
if info.TokensRemaining >= float64(info.Burst) {
t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst)
}
}
func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) {
rl := NewRateLimiter(10, 5)
// Get info for untracked IP (should return default)
info := rl.GetRateLimitInfo("192.168.1.1")
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
if info.TokensRemaining != float64(rl.burst) {
t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst)
}
}
func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) {
rl := NewRateLimiter(10, 10)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
ips := []string{"192.168.1.1", "192.168.1.2"}
for _, ip := range ips {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = ip + ":12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
// Get all rate limit info
allInfo := rl.GetAllRateLimitInfo()
if len(allInfo) != len(ips) {
t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips))
}
// Verify each IP has info
for _, info := range allInfo {
found := false
for _, ip := range ips {
if info.IP == ip {
found = true
break
}
}
if !found {
t.Errorf("Unexpected IP in info: %s", info.IP)
}
}
}
func TestRateLimiter_StatsHandler(t *testing.T) {
rl := NewRateLimiter(10, 5)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.2:12345"
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
// Test stats handler (all IPs)
t.Run("AllIPs", func(t *testing.T) {
statsHandler := rl.StatsHandler()
req := httptest.NewRequest("GET", "/rate-limit-stats", nil)
w := httptest.NewRecorder()
statsHandler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if int(stats["total_tracked_ips"].(float64)) != 2 {
t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"])
}
config := stats["rate_limit_config"].(map[string]interface{})
if config["requests_per_second"].(float64) != 10.0 {
t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"])
}
})
// Test stats handler (specific IP)
t.Run("SpecificIP", func(t *testing.T) {
statsHandler := rl.StatsHandler()
req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil)
w := httptest.NewRecorder()
statsHandler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var info RateLimitInfo
if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
})
}

251
pkg/middleware/sanitize.go Normal file
View File

@ -0,0 +1,251 @@
package middleware
import (
"html"
"net/http"
"regexp"
"strings"
)
// Sanitizer provides input sanitization beyond SQL injection protection
type Sanitizer struct {
// StripHTML removes HTML tags from input
StripHTML bool
// EscapeHTML escapes HTML entities
EscapeHTML bool
// RemoveNullBytes removes null bytes from input
RemoveNullBytes bool
// RemoveControlChars removes control characters (except newline, carriage return, tab)
RemoveControlChars bool
// MaxStringLength limits individual string field length (0 = no limit)
MaxStringLength int
// BlockPatterns are regex patterns to block (e.g., script tags, SQL keywords)
BlockPatterns []*regexp.Regexp
// Custom sanitization function
CustomSanitizer func(string) string
}
// DefaultSanitizer returns a sanitizer with secure defaults
func DefaultSanitizer() *Sanitizer {
return &Sanitizer{
StripHTML: false, // Don't strip by default (breaks legitimate HTML content)
EscapeHTML: true, // Escape HTML entities to prevent XSS
RemoveNullBytes: true, // Remove null bytes (security best practice)
RemoveControlChars: true, // Remove dangerous control characters
MaxStringLength: 0, // No limit by default
// Block common XSS and injection patterns
BlockPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), // Script tags
regexp.MustCompile(`(?i)javascript:`), // JavaScript protocol
regexp.MustCompile(`(?i)on\w+\s*=`), // Event handlers (onclick, onerror, etc.)
regexp.MustCompile(`(?i)<iframe[^>]*>`), // Iframes
regexp.MustCompile(`(?i)<object[^>]*>`), // Objects
regexp.MustCompile(`(?i)<embed[^>]*>`), // Embeds
},
}
}
// StrictSanitizer returns a sanitizer with very strict rules
func StrictSanitizer() *Sanitizer {
s := DefaultSanitizer()
s.StripHTML = true
s.MaxStringLength = 10000
return s
}
// Sanitize sanitizes a string value
func (s *Sanitizer) Sanitize(value string) string {
if value == "" {
return value
}
// Remove null bytes
if s.RemoveNullBytes {
value = strings.ReplaceAll(value, "\x00", "")
}
// Remove control characters
if s.RemoveControlChars {
value = removeControlCharacters(value)
}
// Check block patterns
for _, pattern := range s.BlockPatterns {
if pattern.MatchString(value) {
// Replace matched pattern with empty string
value = pattern.ReplaceAllString(value, "")
}
}
// Strip HTML tags
if s.StripHTML {
value = stripHTMLTags(value)
}
// Escape HTML entities
if s.EscapeHTML && !s.StripHTML {
value = html.EscapeString(value)
}
// Apply max length
if s.MaxStringLength > 0 && len(value) > s.MaxStringLength {
value = value[:s.MaxStringLength]
}
// Apply custom sanitizer
if s.CustomSanitizer != nil {
value = s.CustomSanitizer(value)
}
return value
}
// SanitizeMap sanitizes all string values in a map
func (s *Sanitizer) SanitizeMap(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
result[key] = s.sanitizeValue(value)
}
return result
}
// sanitizeValue recursively sanitizes values
func (s *Sanitizer) sanitizeValue(value interface{}) interface{} {
switch v := value.(type) {
case string:
return s.Sanitize(v)
case map[string]interface{}:
return s.SanitizeMap(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = s.sanitizeValue(item)
}
return result
default:
return value
}
}
// Middleware returns an HTTP middleware that sanitizes request headers and query params
// Note: Body sanitization should be done at the application level after parsing
func (s *Sanitizer) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Sanitize query parameters
if r.URL.RawQuery != "" {
q := r.URL.Query()
sanitized := false
for key, values := range q {
for i, value := range values {
sanitizedValue := s.Sanitize(value)
if sanitizedValue != value {
values[i] = sanitizedValue
sanitized = true
}
}
if sanitized {
q[key] = values
}
}
if sanitized {
r.URL.RawQuery = q.Encode()
}
}
// Sanitize specific headers (User-Agent, Referer, etc.)
dangerousHeaders := []string{
"User-Agent",
"Referer",
"X-Forwarded-For",
"X-Real-IP",
}
for _, header := range dangerousHeaders {
if value := r.Header.Get(header); value != "" {
sanitized := s.Sanitize(value)
if sanitized != value {
r.Header.Set(header, sanitized)
}
}
}
next.ServeHTTP(w, r)
})
}
// Helper functions
// removeControlCharacters removes control characters except \n, \r, \t
func removeControlCharacters(s string) string {
var result strings.Builder
for _, r := range s {
// Keep newline, carriage return, tab, and non-control characters
if r == '\n' || r == '\r' || r == '\t' || r >= 32 {
result.WriteRune(r)
}
}
return result.String()
}
// stripHTMLTags removes HTML tags from a string
func stripHTMLTags(s string) string {
// Simple regex to remove HTML tags
re := regexp.MustCompile(`<[^>]*>`)
return re.ReplaceAllString(s, "")
}
// Common sanitization patterns
// SanitizeFilename sanitizes a filename
func SanitizeFilename(filename string) string {
// Remove path traversal attempts
filename = strings.ReplaceAll(filename, "..", "")
filename = strings.ReplaceAll(filename, "/", "")
filename = strings.ReplaceAll(filename, "\\", "")
// Remove null bytes
filename = strings.ReplaceAll(filename, "\x00", "")
// Limit length
if len(filename) > 255 {
filename = filename[:255]
}
return filename
}
// SanitizeEmail performs basic email sanitization
func SanitizeEmail(email string) string {
email = strings.TrimSpace(strings.ToLower(email))
// Remove dangerous characters
email = strings.ReplaceAll(email, "\x00", "")
email = removeControlCharacters(email)
return email
}
// SanitizeURL performs basic URL sanitization
func SanitizeURL(url string) string {
url = strings.TrimSpace(url)
// Remove null bytes
url = strings.ReplaceAll(url, "\x00", "")
// Block javascript: and data: protocols
if strings.HasPrefix(strings.ToLower(url), "javascript:") {
return ""
}
if strings.HasPrefix(strings.ToLower(url), "data:") {
return ""
}
return url
}

View File

@ -0,0 +1,273 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestSanitizeXSS(t *testing.T) {
sanitizer := DefaultSanitizer()
tests := []struct {
name string
input string
contains string // String that should NOT be in output
}{
{
name: "Script tag",
input: "<script>alert(1)</script>",
contains: "<script>",
},
{
name: "JavaScript protocol",
input: "javascript:alert(1)",
contains: "javascript:",
},
{
name: "Event handler",
input: "<img onerror='alert(1)'>",
contains: "onerror=",
},
{
name: "Iframe",
input: "<iframe src='evil.com'></iframe>",
contains: "<iframe",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizer.Sanitize(tt.input)
if result == tt.input {
t.Errorf("Sanitize() did not modify input: %q", tt.input)
}
})
}
}
func TestSanitizeNullBytes(t *testing.T) {
sanitizer := DefaultSanitizer()
input := "hello\x00world"
result := sanitizer.Sanitize(input)
if result == input {
t.Error("Null bytes should be removed")
}
if len(result) >= len(input) {
t.Errorf("Result length should be less than input: got %d, input %d", len(result), len(input))
}
}
func TestSanitizeControlCharacters(t *testing.T) {
sanitizer := DefaultSanitizer()
// Include various control characters
input := "hello\x01\x02world\x1F"
result := sanitizer.Sanitize(input)
if result == input {
t.Error("Control characters should be removed")
}
// Newlines, tabs, carriage returns should be preserved
input2 := "hello\nworld\t\r"
result2 := sanitizer.Sanitize(input2)
if result2 != input2 {
t.Errorf("Safe control characters should be preserved: got %q, want %q", result2, input2)
}
}
func TestSanitizeMap(t *testing.T) {
sanitizer := DefaultSanitizer()
input := map[string]interface{}{
"name": "<script>alert(1)</script>John",
"email": "test@example.com",
"nested": map[string]interface{}{
"bio": "<iframe src='evil.com'>Bio</iframe>",
},
}
result := sanitizer.SanitizeMap(input)
// Check that script tag was removed/escaped
name, ok := result["name"].(string)
if !ok || name == input["name"] {
t.Error("Name should be sanitized")
}
// Check nested map
nested, ok := result["nested"].(map[string]interface{})
if !ok {
t.Fatal("Nested should still be a map")
}
bio, ok := nested["bio"].(string)
if !ok || bio == input["nested"].(map[string]interface{})["bio"] {
t.Error("Nested bio should be sanitized")
}
}
func TestSanitizeMiddleware(t *testing.T) {
sanitizer := DefaultSanitizer()
handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check that query param was sanitized
param := r.URL.Query().Get("q")
if param == "<script>alert(1)</script>" {
t.Error("Query param should be sanitized")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test?q=<script>alert(1)</script>", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK)
}
}
func TestSanitizeFilename(t *testing.T) {
tests := []struct {
name string
input string
contains string // String that should NOT be in output
}{
{
name: "Path traversal",
input: "../../../etc/passwd",
contains: "..",
},
{
name: "Absolute path",
input: "/etc/passwd",
contains: "/",
},
{
name: "Windows path",
input: "..\\..\\windows\\system32",
contains: "\\",
},
{
name: "Null byte",
input: "file\x00.txt",
contains: "\x00",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
if result == tt.input {
t.Errorf("SanitizeFilename() did not modify input: %q", tt.input)
}
})
}
}
func TestSanitizeEmail(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Uppercase",
input: "TEST@EXAMPLE.COM",
expected: "test@example.com",
},
{
name: "Whitespace",
input: " test@example.com ",
expected: "test@example.com",
},
{
name: "Null bytes",
input: "test\x00@example.com",
expected: "test@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeEmail(tt.input)
if result != tt.expected {
t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeURL(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "JavaScript protocol",
input: "javascript:alert(1)",
expected: "",
},
{
name: "Data protocol",
input: "data:text/html,<script>alert(1)</script>",
expected: "",
},
{
name: "Valid HTTP URL",
input: "https://example.com",
expected: "https://example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeURL(tt.input)
if result != tt.expected {
t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected)
}
})
}
}
func TestStrictSanitizer(t *testing.T) {
sanitizer := StrictSanitizer()
input := "<b>Bold text</b> with <script>alert(1)</script>"
result := sanitizer.Sanitize(input)
// Should strip ALL HTML tags
if result == input {
t.Error("Strict sanitizer should modify input")
}
// Should not contain any HTML tags
if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') {
t.Error("Result should not contain HTML tags")
}
}
func TestMaxStringLength(t *testing.T) {
sanitizer := &Sanitizer{
MaxStringLength: 10,
}
input := "This is a very long string that exceeds the maximum length"
result := sanitizer.Sanitize(input)
if len(result) != 10 {
t.Errorf("Result length = %d, want 10", len(result))
}
if result != input[:10] {
t.Errorf("Result = %q, want %q", result, input[:10])
}
}

View File

@ -0,0 +1,70 @@
package middleware
import (
"fmt"
"net/http"
)
const (
// DefaultMaxRequestSize is the default maximum request body size (10MB)
DefaultMaxRequestSize = 10 * 1024 * 1024 // 10MB
// MaxRequestSizeHeader is the header name for max request size
MaxRequestSizeHeader = "X-Max-Request-Size"
)
// RequestSizeLimiter limits the size of request bodies
type RequestSizeLimiter struct {
maxSize int64
}
// NewRequestSizeLimiter creates a new request size limiter
// maxSize is in bytes. If 0, uses DefaultMaxRequestSize (10MB)
func NewRequestSizeLimiter(maxSize int64) *RequestSizeLimiter {
if maxSize <= 0 {
maxSize = DefaultMaxRequestSize
}
return &RequestSizeLimiter{
maxSize: maxSize,
}
}
// Middleware returns an HTTP middleware that enforces request size limits
func (rsl *RequestSizeLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set max bytes reader on the request body
r.Body = http.MaxBytesReader(w, r.Body, rsl.maxSize)
// Add informational header
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", rsl.maxSize))
next.ServeHTTP(w, r)
})
}
// MiddlewareWithCustomSize returns middleware with a custom size limit function
// This allows different size limits based on the request
func (rsl *RequestSizeLimiter) MiddlewareWithCustomSize(sizeFunc func(*http.Request) int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
maxSize := sizeFunc(r)
if maxSize <= 0 {
maxSize = rsl.maxSize
}
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", maxSize))
next.ServeHTTP(w, r)
})
}
}
// Common size limits
const (
Size1MB = 1 * 1024 * 1024
Size5MB = 5 * 1024 * 1024
Size10MB = 10 * 1024 * 1024
Size50MB = 50 * 1024 * 1024
Size100MB = 100 * 1024 * 1024
)

View File

@ -0,0 +1,126 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestSizeLimiter(t *testing.T) {
// 1KB limit
limiter := NewRequestSizeLimiter(1024)
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try to read body
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}))
// Small request (should succeed)
t.Run("SmallRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 512)) // 512 bytes
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Check header
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" {
t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024")
}
})
// Large request (should fail)
t.Run("LargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048)) // 2KB
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
}
func TestRequestSizeLimiterDefault(t *testing.T) {
// Default limiter (10MB)
limiter := NewRequestSizeLimiter(0)
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024)))
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Check default size
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" {
t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760")
}
}
func TestRequestSizeLimiterWithCustomSize(t *testing.T) {
limiter := NewRequestSizeLimiter(1024)
// Premium users get 10MB, regular users get 1KB
sizeFunc := func(r *http.Request) int64 {
if r.Header.Get("X-User-Tier") == "premium" {
return Size10MB
}
return 1024
}
handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}))
// Regular user with large request (should fail)
t.Run("RegularUserLargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048))
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
// Premium user with large request (should succeed)
t.Run("PremiumUserLargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048))
req := httptest.NewRequest("POST", "/test", body)
req.Header.Set("X-User-Tier", "premium")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK)
}
})
}

View File

@ -6,15 +6,37 @@ import (
"sync"
)
// ModelRules defines the permissions and security settings for a model
type ModelRules struct {
CanRead bool // Whether the model can be read (GET operations)
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanCreate bool // Whether the model can be created (POST operations)
CanDelete bool // Whether the model can be deleted (DELETE operations)
SecurityDisabled bool // Whether security checks are disabled for this model
}
// DefaultModelRules returns the default rules for a model (all operations allowed, security enabled)
func DefaultModelRules() ModelRules {
return ModelRules{
CanRead: true,
CanUpdate: true,
CanCreate: true,
CanDelete: true,
SecurityDisabled: false,
}
}
// DefaultModelRegistry implements ModelRegistry interface
type DefaultModelRegistry struct {
models map[string]interface{}
rules map[string]ModelRules
mutex sync.RWMutex
}
// Global default registry instance
var defaultRegistry = &DefaultModelRegistry{
models: make(map[string]interface{}),
rules: make(map[string]ModelRules),
}
// Global list of registries (searched in order)
@ -25,11 +47,18 @@ var registriesMutex sync.RWMutex
func NewModelRegistry() *DefaultModelRegistry {
return &DefaultModelRegistry{
models: make(map[string]interface{}),
rules: make(map[string]ModelRules),
}
}
func GetDefaultRegistry() *DefaultModelRegistry {
return defaultRegistry
}
func SetDefaultRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock()
defer registriesMutex.Unlock()
foundAt := -1
for idx, r := range registries {
if r == defaultRegistry {
@ -43,9 +72,6 @@ func SetDefaultRegistry(registry *DefaultModelRegistry) {
} else {
registries = append([]*DefaultModelRegistry{registry}, registries...)
}
defer registriesMutex.Unlock()
}
// AddRegistry adds a registry to the global list of registries
@ -95,6 +121,10 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
}
r.models[name] = model
// Initialize with default rules if not already set
if _, exists := r.rules[name]; !exists {
r.rules[name] = DefaultModelRules()
}
return nil
}
@ -132,6 +162,54 @@ func (r *DefaultModelRegistry) GetModelByEntity(schema, entity string) (interfac
return r.GetModel(entity)
}
// SetModelRules sets the rules for a specific model
func (r *DefaultModelRegistry) SetModelRules(name string, rules ModelRules) error {
r.mutex.Lock()
defer r.mutex.Unlock()
// Check if model exists
if _, exists := r.models[name]; !exists {
return fmt.Errorf("model %s not found", name)
}
r.rules[name] = rules
return nil
}
// GetModelRules retrieves the rules for a specific model
// Returns default rules if model exists but rules are not set
func (r *DefaultModelRegistry) GetModelRules(name string) (ModelRules, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Check if model exists
if _, exists := r.models[name]; !exists {
return ModelRules{}, fmt.Errorf("model %s not found", name)
}
// Return rules if set, otherwise return default rules
if rules, exists := r.rules[name]; exists {
return rules, nil
}
return DefaultModelRules(), nil
}
// RegisterModelWithRules registers a model with specific rules
func (r *DefaultModelRegistry) RegisterModelWithRules(name string, model interface{}, rules ModelRules) error {
// First register the model
if err := r.RegisterModel(name, model); err != nil {
return err
}
// Then set the rules (we need to lock again for rules)
r.mutex.Lock()
defer r.mutex.Unlock()
r.rules[name] = rules
return nil
}
// Global convenience functions using the default registry
// RegisterModel registers a model with the default global registry
@ -187,3 +265,34 @@ func GetModels() []interface{} {
return models
}
// SetModelRules sets the rules for a specific model in the default registry
func SetModelRules(name string, rules ModelRules) error {
return defaultRegistry.SetModelRules(name, rules)
}
// GetModelRules retrieves the rules for a specific model from the default registry
func GetModelRules(name string) (ModelRules, error) {
return defaultRegistry.GetModelRules(name)
}
// GetModelRulesByName retrieves the rules for a model by searching through all registries in order
// Returns the first match found
func GetModelRulesByName(name string) (ModelRules, error) {
registriesMutex.RLock()
defer registriesMutex.RUnlock()
for _, registry := range registries {
if _, err := registry.GetModel(name); err == nil {
// Model found in this registry, get its rules
return registry.GetModelRules(name)
}
}
return ModelRules{}, fmt.Errorf("model %s not found in any registry", name)
}
// RegisterModelWithRules registers a model with specific rules in the default registry
func RegisterModelWithRules(model interface{}, name string, rules ModelRules) error {
return defaultRegistry.RegisterModelWithRules(name, model, rules)
}

Some files were not shown because too many files have changed in this diff Show More