Compare commits

...

72 Commits

Author SHA1 Message Date
Hein
1d0407a16d Fixed linting
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-10 10:00:01 +02:00
Hein
99001c749d Better sql where validation 2025-12-10 09:52:13 +02:00
Hein
1f7a57f8e3 Tracking provider 2025-12-10 09:31:55 +02:00
Hein
a95c28a0bf Multi Token warning and handling 2025-12-10 08:44:37 +02:00
Hein
e1abd5ebc1 Enhanced the SanitizeWhereClause function 2025-12-10 08:36:24 +02:00
Hein
ca4e53969b Better tests
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 15:32:16 +02:00
Hein
db2b7e878e Better handling of preloads 2025-12-09 15:12:17 +02:00
Hein
9572bfc7b8 Fix qualified column reference (like APIL.rid_hub) in a preload: 2025-12-09 14:46:33 +02:00
Hein
f0962ea1ec Added EnableQueryDebug log 2025-12-09 14:37:09 +02:00
Hein
8fcb065b42 Better Query Debugging 2025-12-09 14:31:26 +02:00
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
Hein
c52afe2825 Updated sql types 2025-12-09 13:14:22 +02:00
Hein
76e98d02c3 Added modelregistry.GetDefaultRegistry 2025-12-09 12:12:10 +02:00
Hein
23e2db1496 Fixed linting 2025-12-09 12:02:44 +02:00
Hein
d188f49126 Added openapi spec 2025-12-09 12:01:21 +02:00
Hein
0f05202438 Database Authenticator with cache 2025-12-09 11:32:44 +02:00
Hein
b2115038f2 Fixed providers 2025-12-09 11:18:11 +02:00
Hein
229ee4fb28 Fixed DatabaseAuthenticator sq select 2025-12-09 11:05:48 +02:00
Hein
2cf760b979 Added a few auth shortcuts 2025-12-09 10:31:08 +02:00
Hein
0a9c107095 Fixed sqlquery bug in funcspec 2025-12-09 10:19:03 +02:00
Hein
4e2fe33b77 Fixed session_rid in funcspec 2025-12-09 10:04:39 +02:00
Hein
1baa0af0ac Config Package
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 09:19:56 +02:00
Hein
659b2925e4 Cursor pagnation for resolvespec 2025-12-09 08:51:15 +02:00
Hein
baca70cafc Split coverage reports
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:20:40 +02:00
Hein
ed57978620 go-version 1.24 2025-12-08 17:14:04 +02:00
Hein
97b39de88a Broken linting 2025-12-08 17:12:44 +02:00
Hein
bf955b7971 Updated version
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:08:23 +02:00
Hein
545856f8a0 Fixed linting issues 2025-12-08 17:07:13 +02:00
Hein
8d123e47bd Updated deps on workflow 2025-12-08 16:59:49 +02:00
Hein
c9eaf84125 A lot more tests 2025-12-08 16:56:48 +02:00
Hein
aeae9d7e0c Added blacklist middleware 2025-12-08 09:26:36 +02:00
Hein
2a84652dba Middleware enhancements 2025-12-08 08:47:13 +02:00
Hein
b741958895 Code sanity fixes, added middlewares
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-08 08:28:43 +02:00
Hein
2442589982 Better headers
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-03 14:42:38 +02:00
Hein
7c1bae60c9 Added meta handlers 2025-12-03 13:52:06 +02:00
Hein
06b2404c0c Remove blank array if no args 2025-12-03 12:25:51 +02:00
Hein
32007480c6 Handle cql columns as text by default 2025-12-03 12:18:33 +02:00
Hein
ab1ce869b6 Handling JSON responses in funcspec 2025-12-03 12:10:13 +02:00
Hein
ff72e04428 Added meta operation. 2025-12-03 11:59:58 +02:00
Hein
e35f8a4f14 Fix session id that is an integer. 2025-12-03 11:49:19 +02:00
Hein
5ff9a8a24e Fixed blank params on funcspec 2025-12-03 11:42:32 +02:00
Hein
81b87af6e4 Updated doc 2025-12-03 11:30:59 +02:00
Hein
f3ba314640 Refectored the mux routers. 2025-12-03 10:42:26 +02:00
Hein
93df33e274 UnderlyingRequest and UnderlyingResponseWriter
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-12-02 17:40:44 +02:00
Hein
abd045493a mux UnderlyingRequest 2025-12-02 17:34:18 +02:00
Hein
a61556d857 Added FallbackHandler 2025-12-02 17:16:34 +02:00
Hein
eaf1133575 Fixed security rules not loading 2025-12-02 16:55:12 +02:00
Hein
8172c0495d More generic security solution. 2025-12-02 16:35:08 +02:00
Hein
7a3c368121 Pass through to default handler 2025-12-02 16:09:36 +02:00
Hein
9c5c7689e9 More common handler interface 2025-12-02 15:45:24 +02:00
Hein
08050c960d Optional Authentication 2025-12-02 14:14:38 +02:00
Hein
78029fb34f Fixed formatting issues
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-01 14:56:30 +02:00
Hein
1643a5e920 Added cache, funcspec and implemented total cache 2025-12-01 14:40:54 +02:00
Hein
6bbe0ec8b0 Added function api prototype
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-24 17:00:15 +02:00
Hein
e32ec9e17e Updated the security package 2025-11-24 17:00:05 +02:00
Hein
26c175e65e Added make release to vscode tasks
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-24 10:15:23 +02:00
Hein
aa99e8e4bc Added WrapHTTPRequest 2025-11-24 10:13:48 +02:00
Hein
163593901f Huge preload chains causing errors, workaround to do seperate selects.
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-21 17:09:11 +02:00
Hein
1261960e97 Ability to handle multiple x-custom- headers
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-21 12:15:07 +02:00
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
Hein
a931b8cdd2 Better preloads 2025-11-21 10:41:58 +02:00
Hein
7e76977dcc Lots of refactoring, Fixes to preloads 2025-11-21 10:17:20 +02:00
Hein
7853a3f56a cql_columns parsing and recursive preloading. Also added legacy header support for limt(s,e) ,sort(x,y,-z) 2025-11-21 09:15:40 +02:00
Hein
c2e0c36c79 Restheadspec now takes parameters from query parameters and headers. Allows for backward compatibility with our old dojo clients 2025-11-21 08:56:58 +02:00
Hein
59bd709460 More reflection function to handle sql columns and get default sqlcolumn lists. 2025-11-21 08:35:46 +02:00
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
139 changed files with 34161 additions and 3460 deletions

1
.claude/readme Normal file
View File

@@ -0,0 +1 @@
We use claude for testing and document generation.

52
.env.example Normal file
View File

@@ -0,0 +1,52 @@
# ResolveSpec Environment Variables Example
# Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
# Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
# Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false
RESOLVESPEC_TRACING_SERVICE_NAME=resolvespec
RESOLVESPEC_TRACING_SERVICE_VERSION=1.0.0
RESOLVESPEC_TRACING_ENDPOINT=http://localhost:4318/v1/traces
# Cache Configuration
RESOLVESPEC_CACHE_PROVIDER=memory
# Redis Cache (when provider=redis)
RESOLVESPEC_CACHE_REDIS_HOST=localhost
RESOLVESPEC_CACHE_REDIS_PORT=6379
RESOLVESPEC_CACHE_REDIS_PASSWORD=
RESOLVESPEC_CACHE_REDIS_DB=0
# Memcache (when provider=memcache)
# Note: For arrays, separate values with commas
RESOLVESPEC_CACHE_MEMCACHE_SERVERS=localhost:11211
RESOLVESPEC_CACHE_MEMCACHE_MAX_IDLE_CONNS=10
RESOLVESPEC_CACHE_MEMCACHE_TIMEOUT=100ms
# Logger Configuration
RESOLVESPEC_LOGGER_DEV=false
RESOLVESPEC_LOGGER_PATH=
# Middleware Configuration
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_RPS=100.0
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_BURST=200
RESOLVESPEC_MIDDLEWARE_MAX_REQUEST_SIZE=10485760
# CORS Configuration
# Note: For arrays in env vars, separate with commas
RESOLVESPEC_CORS_ALLOWED_ORIGINS=*
RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600
# Database Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable

View File

@@ -1,4 +1,4 @@
name: Tests name: Build , Vet Test, and Lint
on: on:
push: push:
@@ -9,7 +9,7 @@ on:
jobs: jobs:
test: test:
name: Run Tests name: Run Vet Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
@@ -38,22 +38,6 @@ jobs:
- name: Run go vet - name: Run go vet
run: go vet ./... run: go vet ./...
- name: Run tests
run: go test -v -race -coverprofile=coverage.out -covermode=atomic ./...
- name: Display test coverage
run: go tool cover -func=coverage.out
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v4
# with:
# file: ./coverage.out
# flags: unittests
# name: codecov-umbrella
# env:
# CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# continue-on-error: true
lint: lint:
name: Lint Code name: Lint Code
runs-on: ubuntu-latest runs-on: ubuntu-latest

81
.github/workflows/tests.yml vendored Normal file
View File

@@ -0,0 +1,81 @@
name: Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
jobs:
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Run unit tests
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
- name: Generate coverage report
run: |
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
- name: Upload coverage
uses: actions/upload-artifact@v5
with:
name: coverage-report
path: coverage.html
integration-tests:
name: Integration Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15-alpine
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Create test databases
env:
PGPASSWORD: postgres
run: |
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
- name: Run resolvespec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
- name: Run restheadspec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
- name: Generate integration coverage
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: |
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
- name: Upload resolvespec integration coverage
uses: actions/upload-artifact@v5
with:
name: resolvespec-integration-coverage-report
path: coverage-resolvespec-integration.html
- name: Upload restheadspec integration coverage
uses: actions/upload-artifact@v5
with:
name: integration-coverage-restheadspec-report
path: coverage-restheadspec-integration

View File

@@ -71,41 +71,26 @@
}, },
"gocritic": { "gocritic": {
"enabled-checks": [ "enabled-checks": [
"appendAssign",
"assignOp",
"boolExprSimplify", "boolExprSimplify",
"builtinShadow", "builtinShadow",
"captLocal",
"caseOrder",
"defaultCaseOrder",
"dupArg",
"dupBranchBody",
"dupCase",
"dupSubExpr",
"elseif",
"emptyFallthrough", "emptyFallthrough",
"equalFold", "equalFold",
"flagName",
"ifElseChain",
"indexAlloc", "indexAlloc",
"initClause", "initClause",
"methodExprCall", "methodExprCall",
"nilValReturn", "nilValReturn",
"rangeExprCopy", "rangeExprCopy",
"rangeValCopy", "rangeValCopy",
"regexpMust",
"singleCaseSwitch",
"sloppyLen",
"stringXbytes", "stringXbytes",
"switchTrue",
"typeAssertChain", "typeAssertChain",
"typeSwitchVar",
"underef",
"unlabelStmt", "unlabelStmt",
"unnamedResult", "unnamedResult",
"unnecessaryBlock", "unnecessaryBlock",
"weakCond", "weakCond",
"yodaStyleExpr" "yodaStyleExpr"
],
"disabled-checks": [
"ifElseChain"
] ]
}, },
"revive": { "revive": {

56
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,56 @@
{
"go.testFlags": [
"-v"
],
"go.testTimeout": "300s",
"go.coverOnSave": false,
"go.coverOnSingleTest": true,
"go.coverageDecorator": {
"type": "gutter"
},
"go.testEnvVars": {
"TEST_DATABASE_URL": "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
},
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
"go.buildTags": "",
"go.testTags": "",
"files.exclude": {
"**/.git": true,
"**/.DS_Store": true,
"**/coverage.out": true,
"**/coverage.html": true,
"**/coverage-integration.out": true,
"**/coverage-integration.html": true
},
"files.watcherExclude": {
"**/.git/objects/**": true,
"**/.git/subtree-cache/**": true,
"**/node_modules/*/**": true,
"**/.hg/store/**": true,
"**/vendor/**": true
},
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
"[go]": {
"editor.defaultFormatter": "golang.go",
"editor.formatOnSave": true,
"editor.insertSpaces": false,
"editor.tabSize": 4
},
"gopls": {
"ui.completion.usePlaceholders": true,
"ui.semanticTokens": true,
"ui.codelenses": {
"generate": true,
"regenerate_cgo": true,
"test": true,
"tidy": true,
"upgrade_dependency": true,
"vendor": true
}
}
}

224
.vscode/tasks.json vendored
View File

@@ -6,10 +6,10 @@
"label": "go: build workspace", "label": "go: build workspace",
"command": "build", "command": "build",
"options": { "options": {
"env": { "env": {
"CGO_ENABLED": "0" "CGO_ENABLED": "0"
}, },
"cwd": "${workspaceFolder}/bin", "cwd": "${workspaceFolder}/bin"
}, },
"args": [ "args": [
"../..." "../..."
@@ -17,12 +17,179 @@
"problemMatcher": [ "problemMatcher": [
"$go" "$go"
], ],
"group": "build", "group": "build"
},
{
"type": "shell",
"label": "test: unit tests (all)",
"command": "go test ./pkg/resolvespec ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": {
"kind": "test",
"isDefault": true
},
"presentation": {
"reveal": "always",
"panel": "shared",
"focus": true
}
},
{
"type": "shell",
"label": "test: unit tests (resolvespec)",
"command": "go test ./pkg/resolvespec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: unit tests (restheadspec)",
"command": "go test ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration tests (automated)",
"command": "./scripts/run-integration-tests.sh",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: integration tests (resolvespec only)",
"command": "./scripts/run-integration-tests.sh resolvespec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: integration tests (restheadspec only)",
"command": "./scripts/run-integration-tests.sh restheadspec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: coverage report",
"command": "make coverage",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration coverage report",
"command": "make coverage-integration",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: start postgres",
"command": "make docker-up",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: stop postgres",
"command": "make docker-down",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: clean postgres data",
"command": "make clean",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
}, },
{ {
"type": "go", "type": "go",
"label": "go: test workspace", "label": "go: test workspace (with race)",
"command": "test", "command": "test",
"options": { "options": {
"cwd": "${workspaceFolder}" "cwd": "${workspaceFolder}"
@@ -37,13 +204,10 @@
"problemMatcher": [ "problemMatcher": [
"$go" "$go"
], ],
"group": { "group": "test",
"kind": "test",
"isDefault": true
},
"presentation": { "presentation": {
"reveal": "always", "reveal": "always",
"panel": "new" "panel": "shared"
} }
}, },
{ {
@@ -70,17 +234,45 @@
}, },
{ {
"type": "shell", "type": "shell",
"label": "go: full test suite", "label": "test: all tests (unit + integration)",
"command": "make test",
"options": {
"cwd": "${workspaceFolder}"
},
"dependsOn": [
"docker: start postgres"
],
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: full suite with checks",
"dependsOrder": "sequence", "dependsOrder": "sequence",
"dependsOn": [ "dependsOn": [
"go: vet workspace", "go: vet workspace",
"go: test workspace" "test: unit tests (all)",
"test: integration tests (automated)"
], ],
"problemMatcher": [], "problemMatcher": [],
"group": { "group": "test",
"kind": "test", "presentation": {
"isDefault": false "reveal": "always",
"panel": "dedicated"
} }
},
{
"type": "shell",
"label": "Make Release",
"problemMatcher": [],
"command": "sh ${workspaceFolder}/make_release.sh"
} }
] ]
} }

View File

@@ -1,173 +0,0 @@
# Migration Guide: Database and Router Abstraction
This guide explains how to migrate from the direct GORM/Router dependencies to the new abstracted interfaces.
## Overview of Changes
### What was changed:
1. **Database Operations**: GORM-specific code is now abstracted behind `Database` interface
2. **Router Integration**: HTTP router dependencies are abstracted behind `Router` interface
3. **Model Registry**: Models are now managed through a `ModelRegistry` interface
4. **Backward Compatibility**: Existing code continues to work with `NewAPIHandler()`
### Benefits:
- **Database Flexibility**: Switch between GORM, Bun, or other ORMs without code changes
- **Router Flexibility**: Use Gorilla Mux, Gin, Echo, or other routers
- **Better Testing**: Easy to mock database and router interactions
- **Cleaner Separation**: Business logic separated from ORM/router specifics
## Migration Path
### Option 1: No Changes Required (Backward Compatible)
Your existing code continues to work without any changes:
```go
// This still works exactly as before
handler := resolvespec.NewAPIHandler(db)
```
### Option 2: Gradual Migration to New API
#### Step 1: Use New Handler Constructor
```go
// Old way
handler := resolvespec.NewAPIHandler(gormDB)
// New way
handler := resolvespec.NewHandlerWithGORM(gormDB)
```
#### Step 2: Use Interface-based Approach
```go
// Create database adapter
dbAdapter := resolvespec.NewGormAdapter(gormDB)
// Create model registry
registry := resolvespec.NewModelRegistry()
// Register your models
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
// Create handler
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Switching Database Backends
### From GORM to Bun
```go
// Add bun dependency first:
// go get github.com/uptrace/bun
// Old GORM setup
gormDB, _ := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
gormAdapter := resolvespec.NewGormAdapter(gormDB)
// New Bun setup
sqlDB, _ := sql.Open("sqlite3", "test.db")
bunDB := bun.NewDB(sqlDB, sqlitedialect.New())
bunAdapter := resolvespec.NewBunAdapter(bunDB)
// Handler creation is identical
handler := resolvespec.NewHandler(bunAdapter, registry)
```
## Router Flexibility
### Current Gorilla Mux (Default)
```go
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
```
### BunRouter (Built-in Support)
```go
// Simple setup
router := bunrouter.New()
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
// Or using adapter
routerAdapter := resolvespec.NewStandardBunRouterAdapter()
// Use routerAdapter.GetBunRouter() for the underlying router
```
### Using Router Adapters (Advanced)
```go
// For when you want router abstraction
routerAdapter := resolvespec.NewStandardRouter()
routerAdapter.RegisterRoute("/{schema}/{entity}", handlerFunc)
```
## Model Registration
### Old Way (Still Works)
```go
// Models registered through existing models package
handler.RegisterModel("public", "users", &User{})
```
### New Way (Recommended)
```go
registry := resolvespec.NewModelRegistry()
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Interface Definitions
### Database Interface
```go
type Database interface {
NewSelect() SelectQuery
NewInsert() InsertQuery
NewUpdate() UpdateQuery
NewDelete() DeleteQuery
// ... transaction methods
}
```
### Available Adapters
- `GormAdapter` - For GORM (ready to use)
- `BunAdapter` - For Bun (add dependency: `go get github.com/uptrace/bun`)
- Easy to create custom adapters for other ORMs
## Testing Benefits
### Before (Tightly Coupled)
```go
// Hard to test - requires real GORM setup
func TestHandler(t *testing.T) {
db := setupRealGormDB()
handler := resolvespec.NewAPIHandler(db)
// ... test logic
}
```
### After (Mockable)
```go
// Easy to test - mock the interfaces
func TestHandler(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := resolvespec.NewHandler(mockDB, mockRegistry)
// ... test logic with mocks
}
```
## Breaking Changes
- **None for existing code** - Full backward compatibility maintained
- New interfaces are additive, not replacing existing APIs
## Recommended Migration Timeline
1. **Phase 1**: Use existing code (no changes needed)
2. **Phase 2**: Gradually adopt new constructors (`NewHandlerWithGORM`)
3. **Phase 3**: Move to interface-based approach when needed
4. **Phase 4**: Switch database backends if desired
## Getting Help
- Check example functions in `resolvespec.go`
- Review interface definitions in `database.go`
- Examine adapter implementations for patterns

66
Makefile Normal file
View File

@@ -0,0 +1,66 @@
.PHONY: test test-unit test-integration docker-up docker-down clean
# Run all unit tests
test-unit:
@echo "Running unit tests..."
@go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
# Run all integration tests (requires PostgreSQL)
test-integration:
@echo "Running integration tests..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
# Run all tests (unit + integration)
test: test-unit test-integration
# Start PostgreSQL for integration tests
docker-up:
@echo "Starting PostgreSQL container..."
@docker-compose up -d postgres-test
@echo "Waiting for PostgreSQL to be ready..."
@sleep 5
@echo "PostgreSQL is ready!"
# Stop PostgreSQL container
docker-down:
@echo "Stopping PostgreSQL container..."
@docker-compose down
# Clean up Docker volumes and test data
clean:
@echo "Cleaning up..."
@docker-compose down -v
@echo "Cleanup complete!"
# Run integration tests with Docker (full workflow)
test-integration-docker: docker-up
@echo "Running integration tests with Docker..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
@$(MAKE) docker-down
# Check test coverage
coverage:
@echo "Generating coverage report..."
@go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
@go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# Run integration tests coverage
coverage-integration:
@echo "Generating integration test coverage report..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage-integration.out
@go tool cover -html=coverage-integration.out -o coverage-integration.html
@echo "Integration coverage report generated: coverage-integration.html"
help:
@echo "Available targets:"
@echo " test-unit - Run unit tests"
@echo " test-integration - Run integration tests (requires PostgreSQL)"
@echo " test - Run all tests"
@echo " docker-up - Start PostgreSQL container"
@echo " docker-down - Stop PostgreSQL container"
@echo " test-integration-docker - Run integration tests with Docker (automated)"
@echo " clean - Clean up Docker volumes"
@echo " coverage - Generate unit test coverage report"
@echo " coverage-integration - Generate integration test coverage report"
@echo " help - Show this help message"

636
README.md
View File

@@ -1,73 +1,83 @@
# 📜 ResolveSpec 📜 # 📜 ResolveSpec 📜
![Tests](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg) ![1.00](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**: ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
1. **ResolveSpec** - Body-based API with JSON request options 1. **ResolveSpec** - Body-based API with JSON request options
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers 2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering. Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
**🆕 New in v2.0**: Database-agnostic architecture with support for GORM, Bun, and other ORMs. Router-flexible design works with Gorilla Mux, Gin, Echo, and more. Documentation Generated by LLMs
**🆕 New in v2.1**: RestHeadSpec (HeaderSpec) - Header-based REST API with lifecycle hooks, cursor pagination, and advanced filtering. ![1.00](./generated_slogan.webp)
![slogan](./generated_slogan.webp)
## Table of Contents ## Table of Contents
- [Features](#features) * [Features](#features)
- [Installation](#installation) * [Installation](#installation)
- [Quick Start](#quick-start) * [Quick Start](#quick-start)
- [ResolveSpec (Body-Based API)](#resolvespec-body-based-api) * [ResolveSpec (Body-Based API)](#resolvespec-body-based-api)
- [RestHeadSpec (Header-Based API)](#restheadspec-header-based-api) * [RestHeadSpec (Header-Based API)](#restheadspec-header-based-api)
- [Existing Code (Backward Compatible)](#option-1-existing-code-backward-compatible) * [Existing Code (Backward Compatible)](#option-1-existing-code-backward-compatible)
- [New Database-Agnostic API](#option-2-new-database-agnostic-api) * [New Database-Agnostic API](#option-2-new-database-agnostic-api)
- [Router Integration](#router-integration) * [Router Integration](#router-integration)
- [Migration from v1.x](#migration-from-v1x) * [Migration from v1.x](#migration-from-v1x)
- [Architecture](#architecture) * [Architecture](#architecture)
- [API Structure](#api-structure) * [API Structure](#api-structure)
- [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1) * [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1)
- [Lifecycle Hooks](#lifecycle-hooks) * [Lifecycle Hooks](#lifecycle-hooks)
- [Cursor Pagination](#cursor-pagination) * [Cursor Pagination](#cursor-pagination)
- [Response Formats](#response-formats) * [Response Formats](#response-formats)
- [Single Record as Object](#single-record-as-object-default-behavior) * [Single Record as Object](#single-record-as-object-default-behavior)
- [Example Usage](#example-usage) * [Example Usage](#example-usage)
- [Recursive CRUD Operations](#recursive-crud-operations-) * [Recursive CRUD Operations](#recursive-crud-operations-)
- [Testing](#testing) * [Testing](#testing)
- [What's New](#whats-new) * [What's New](#whats-new)
## Features ## Features
### Core Features ### Core Features
- **Dynamic Data Querying**: Select specific columns and relationships to return
- **Relationship Preloading**: Load related entities with custom column selection and filters * **Dynamic Data Querying**: Select specific columns and relationships to return
- **Complex Filtering**: Apply multiple filters with various operators * **Relationship Preloading**: Load related entities with custom column selection and filters
- **Sorting**: Multi-column sort support * **Complex Filtering**: Apply multiple filters with various operators
- **Pagination**: Built-in limit/offset and cursor-based pagination * **Sorting**: Multi-column sort support
- **Computed Columns**: Define virtual columns for complex calculations * **Pagination**: Built-in limit/offset and cursor-based pagination (both ResolveSpec and RestHeadSpec)
- **Custom Operators**: Add custom SQL conditions when needed * **Computed Columns**: Define virtual columns for complex calculations
- **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field * **Custom Operators**: Add custom SQL conditions when needed
* **🆕 Recursive CRUD Handler**: Automatically handle nested object graphs with foreign key resolution and per-record operation control via `_request` field
### Architecture (v2.0+) ### Architecture (v2.0+)
- **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
- **🆕 Router Flexible**: Integrates with Gorilla Mux, Gin, Echo, or custom routers * **🆕 Database Agnostic**: Works with GORM, Bun, or any database layer through adapters
- **🆕 Backward Compatible**: Existing code works without changes * **🆕 Router Flexible**: Integrates with Gorilla Mux, Gin, Echo, or custom routers
- **🆕 Better Testing**: Mockable interfaces for easy unit testing * **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
### RestHeadSpec (v2.1+) ### RestHeadSpec (v2.1+)
- **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
- **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations * **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
- **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support * **🆕 Lifecycle Hooks**: Before/after hooks for create, read, update, and delete operations
- **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats * **🆕 Cursor Pagination**: Efficient cursor-based pagination with complex sort support
- **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default) * **🆕 Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible formats
- **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL * **🆕 Single Record as Object**: Automatically normalize single-element arrays to objects (enabled by default)
- **🆕 Base64 Encoding**: Support for base64-encoded header values * **🆕 Advanced Filtering**: Field filters, search operators, AND/OR logic, and custom SQL
* **🆕 Base64 Encoding**: Support for base64-encoded header values
### Routing & CORS (v3.0+)
* **🆕 Explicit Route Registration**: Routes created per registered model instead of dynamic lookups
* **🆕 OPTIONS Method Support**: Full OPTIONS method support returning model metadata
* **🆕 CORS Headers**: Comprehensive CORS support with all HeadSpec headers allowed
* **🆕 Better Route Control**: Customize routes per model with more flexibility
## API Structure ## API Structure
### URL Patterns ### URL Patterns
``` ```
/[schema]/[table_or_entity]/[id] /[schema]/[table_or_entity]/[id]
/[schema]/[table_or_entity] /[schema]/[table_or_entity]
@@ -77,7 +87,7 @@ Both share the same core architecture and provide dynamic data querying, relatio
### Request Format ### Request Format
```json ```JSON
{ {
"operation": "read|create|update|delete", "operation": "read|create|update|delete",
"data": { "data": {
@@ -102,7 +112,7 @@ RestHeadSpec provides an alternative REST API approach where all query options a
### Quick Example ### Quick Example
```http ```HTTP
GET /public/users HTTP/1.1 GET /public/users HTTP/1.1
Host: api.example.com Host: api.example.com
X-Select-Fields: id,name,email,department_id X-Select-Fields: id,name,email,department_id
@@ -116,20 +126,22 @@ X-DetailApi: true
### Setup with GORM ### Setup with GORM
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec" import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/gorilla/mux" import "github.com/gorilla/mux"
// Create handler // Create handler
handler := restheadspec.NewHandlerWithGORM(db) handler := restheadspec.NewHandlerWithGORM(db)
// Register models using schema.table format // IMPORTANT: Register models BEFORE setting up routes
// Routes are created explicitly for each registered model
handler.Registry.RegisterModel("public.users", &User{}) handler.Registry.RegisterModel("public.users", &User{})
handler.Registry.RegisterModel("public.posts", &Post{}) handler.Registry.RegisterModel("public.posts", &Post{})
// Setup routes // Setup routes (creates explicit routes for each registered model)
// This replaces the old dynamic route lookup approach
router := mux.NewRouter() router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler) restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server // Start server
http.ListenAndServe(":8080", router) http.ListenAndServe(":8080", router)
@@ -137,7 +149,7 @@ http.ListenAndServe(":8080", router)
### Setup with Bun ORM ### Setup with Bun ORM
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec" import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
import "github.com/uptrace/bun" import "github.com/uptrace/bun"
@@ -154,29 +166,70 @@ restheadspec.SetupMuxRoutes(router, handler)
### Common Headers ### Common Headers
| Header | Description | Example | | Header | Description | Example |
|--------|-------------|---------| | --------------------------- | -------------------------------------------------- | ------------------------------ |
| `X-Select-Fields` | Columns to include | `id,name,email` | | `X-Select-Fields` | Columns to include | `id,name,email` |
| `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` | | `X-Not-Select-Fields` | Columns to exclude | `password,internal_notes` |
| `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` | | `X-FieldFilter-{col}` | Exact match filter | `X-FieldFilter-Status: active` |
| `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` | | `X-SearchFilter-{col}` | Fuzzy search (ILIKE) | `X-SearchFilter-Name: john` |
| `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` | | `X-SearchOp-{op}-{col}` | Filter with operator | `X-SearchOp-Gte-Age: 18` |
| `X-Preload` | Preload relations | `posts:id,title` | | `X-Preload` | Preload relations | `posts:id,title` |
| `X-Sort` | Sort columns | `-created_at,+name` | | `X-Sort` | Sort columns | `-created_at,+name` |
| `X-Limit` | Limit results | `50` | | `X-Limit` | Limit results | `50` |
| `X-Offset` | Offset for pagination | `100` | | `X-Offset` | Offset for pagination | `100` |
| `X-Clean-JSON` | Remove null/empty fields | `true` | | `X-Clean-JSON` | Remove null/empty fields | `true` |
| `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` | | `X-Single-Record-As-Object` | Return single records as objects (default: `true`) | `false` |
**Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty` **Available Operators**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `contains`, `startswith`, `endswith`, `between`, `betweeninclusive`, `in`, `empty`, `notempty`
For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md). For complete header documentation, see [pkg/restheadspec/HEADERS.md](pkg/restheadspec/HEADERS.md).
### CORS & OPTIONS Support
ResolveSpec and RestHeadSpec include comprehensive CORS support for cross-origin requests:
**OPTIONS Method**:
```HTTP
OPTIONS /public/users HTTP/1.1
```
Returns metadata with appropriate CORS headers:
```HTTP
Access-Control-Allow-Origin: *
Access-Control-Allow-Methods: GET, POST, OPTIONS
Access-Control-Allow-Headers: Content-Type, Authorization, X-Select-Fields, X-FieldFilter-*, ...
Access-Control-Max-Age: 86400
Access-Control-Allow-Credentials: true
```
**Key Features**:
* OPTIONS returns model metadata (same as GET metadata endpoint)
* All HTTP methods include CORS headers automatically
* OPTIONS requests don't require authentication (CORS preflight)
* Supports all HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.)
* 24-hour max age to reduce preflight requests
**Configuration**:
```Go
import "github.com/bitechdev/ResolveSpec/pkg/common"
// Get default CORS config
corsConfig := common.DefaultCORSConfig()
// Customize if needed
corsConfig.AllowedOrigins = []string{"https://example.com"}
corsConfig.AllowedMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
```
### Lifecycle Hooks ### Lifecycle Hooks
RestHeadSpec supports lifecycle hooks for all CRUD operations: RestHeadSpec supports lifecycle hooks for all CRUD operations:
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec" import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler // Create handler
@@ -221,27 +274,29 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
``` ```
**Available Hook Types**: **Available Hook Types**:
- `BeforeRead`, `AfterRead`
- `BeforeCreate`, `AfterCreate` * `BeforeRead`, `AfterRead`
- `BeforeUpdate`, `AfterUpdate` * `BeforeCreate`, `AfterCreate`
- `BeforeDelete`, `AfterDelete` * `BeforeUpdate`, `AfterUpdate`
* `BeforeDelete`, `AfterDelete`
**HookContext** provides: **HookContext** provides:
- `Context`: Request context
- `Handler`: Access to handler, database, and registry * `Context`: Request context
- `Schema`, `Entity`, `TableName`: Request info * `Handler`: Access to handler, database, and registry
- `Model`: The registered model type * `Schema`, `Entity`, `TableName`: Request info
- `Options`: Parsed request options (filters, sorting, etc.) * `Model`: The registered model type
- `ID`: Record ID (for single-record operations) * `Options`: Parsed request options (filters, sorting, etc.)
- `Data`: Request data (for create/update) * `ID`: Record ID (for single-record operations)
- `Result`: Operation result (for after hooks) * `Data`: Request data (for create/update)
- `Writer`: Response writer (allows hooks to modify response) * `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
### Cursor Pagination ### Cursor Pagination
RestHeadSpec supports efficient cursor-based pagination for large datasets: RestHeadSpec supports efficient cursor-based pagination for large datasets:
```http ```HTTP
GET /public/posts HTTP/1.1 GET /public/posts HTTP/1.1
X-Sort: -created_at,+id X-Sort: -created_at,+id
X-Limit: 50 X-Limit: 50
@@ -249,20 +304,22 @@ X-Cursor-Forward: <cursor_token>
``` ```
**How it works**: **How it works**:
1. First request returns results + cursor token in response 1. First request returns results + cursor token in response
2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward` 2. Subsequent requests use `X-Cursor-Forward` or `X-Cursor-Backward`
3. Cursor maintains consistent ordering even with data changes 3. Cursor maintains consistent ordering even with data changes
4. Supports complex multi-column sorting 4. Supports complex multi-column sorting
**Benefits over offset pagination**: **Benefits over offset pagination**:
- Consistent results when data changes
- Better performance for large offsets * Consistent results when data changes
- Prevents "skipped" or duplicate records * Better performance for large offsets
- Works with complex sort expressions * Prevents "skipped" or duplicate records
* Works with complex sort expressions
**Example with hooks**: **Example with hooks**:
```go ```Go
// Enable cursor pagination in a hook // Enable cursor pagination in a hook
handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error { handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookContext) error {
// For large tables, enforce cursor pagination // For large tables, enforce cursor pagination
@@ -278,7 +335,8 @@ handler.Hooks.Register(restheadspec.BeforeRead, func(ctx *restheadspec.HookConte
RestHeadSpec supports multiple response formats: RestHeadSpec supports multiple response formats:
**1. Simple Format** (`X-SimpleApi: true`): **1. Simple Format** (`X-SimpleApi: true`):
```json
```JSON
[ [
{ "id": 1, "name": "John" }, { "id": 1, "name": "John" },
{ "id": 2, "name": "Jane" } { "id": 2, "name": "Jane" }
@@ -286,7 +344,8 @@ RestHeadSpec supports multiple response formats:
``` ```
**2. Detail Format** (`X-DetailApi: true`, default): **2. Detail Format** (`X-DetailApi: true`, default):
```json
```JSON
{ {
"success": true, "success": true,
"data": [...], "data": [...],
@@ -300,7 +359,8 @@ RestHeadSpec supports multiple response formats:
``` ```
**3. Syncfusion Format** (`X-Syncfusion: true`): **3. Syncfusion Format** (`X-Syncfusion: true`):
```json
```JSON
{ {
"result": [...], "result": [...],
"count": 100 "count": 100
@@ -312,10 +372,12 @@ RestHeadSpec supports multiple response formats:
By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records. By default, RestHeadSpec automatically converts single-element arrays into objects for cleaner API responses. This provides a better developer experience when fetching individual records.
**Default behavior (enabled)**: **Default behavior (enabled)**:
```http
```HTTP
GET /public/users/123 GET /public/users/123
``` ```
```json
```JSON
{ {
"success": true, "success": true,
"data": { "id": 123, "name": "John", "email": "john@example.com" } "data": { "id": 123, "name": "John", "email": "john@example.com" }
@@ -323,7 +385,8 @@ GET /public/users/123
``` ```
Instead of: Instead of:
```json
```JSON
{ {
"success": true, "success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }] "data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
@@ -331,11 +394,13 @@ Instead of:
``` ```
**To disable** (force arrays for consistency): **To disable** (force arrays for consistency):
```http
```HTTP
GET /public/users/123 GET /public/users/123
X-Single-Record-As-Object: false X-Single-Record-As-Object: false
``` ```
```json
```JSON
{ {
"success": true, "success": true,
"data": [{ "id": 123, "name": "John", "email": "john@example.com" }] "data": [{ "id": 123, "name": "John", "email": "john@example.com" }]
@@ -343,23 +408,26 @@ X-Single-Record-As-Object: false
``` ```
**How it works**: **How it works**:
- When a query returns exactly **one record**, it's returned as an object
- When a query returns **multiple records**, they're returned as an array * When a query returns exactly **one record**, it's returned as an object
- Set `X-Single-Record-As-Object: false` to always receive arrays * When a query returns **multiple records**, they're returned as an array
- Works with all response formats (simple, detail, syncfusion) * Set `X-Single-Record-As-Object: false` to always receive arrays
- Applies to both read operations and create/update returning clauses * Works with all response formats (simple, detail, syncfusion)
* Applies to both read operations and create/update returning clauses
**Benefits**: **Benefits**:
- Cleaner API responses for single-record queries
- No need to unwrap single-element arrays on the client side * Cleaner API responses for single-record queries
- Better TypeScript/type inference support * No need to unwrap single-element arrays on the client side
- Consistent with common REST API patterns * Better TypeScript/type inference support
- Backward compatible via header opt-out * Consistent with common REST API patterns
* Backward compatible via header opt-out
## Example Usage ## Example Usage
### Reading Data with Related Entities ### Reading Data with Related Entities
```json
```JSON
POST /core/users POST /core/users
{ {
"operation": "read", "operation": "read",
@@ -397,13 +465,89 @@ POST /core/users
} }
``` ```
### Cursor Pagination (ResolveSpec)
ResolveSpec now supports cursor-based pagination for efficient traversal of large datasets:
```JSON
POST /core/posts
{
"operation": "read",
"options": {
"sort": [
{
"column": "created_at",
"direction": "desc"
},
{
"column": "id",
"direction": "asc"
}
],
"limit": 50,
"cursor_forward": "12345"
}
}
```
**How it works**:
1. First request returns results + cursor token (last record's ID)
2. Subsequent requests use `cursor_forward` or `cursor_backward` in options
3. Cursor maintains consistent ordering even when data changes
4. Supports complex multi-column sorting
**Benefits over offset pagination**:
- Consistent results when data changes between requests
- Better performance for large offsets
- Prevents "skipped" or duplicate records
- Works with complex sort expressions
**Example request sequence**:
```JSON
// First request - no cursor
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50
}
}
// Response includes data + last record ID
// Use the last record's ID as cursor_forward for next page
// Second request - with cursor
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50,
"cursor_forward": "12345" // ID of last record from previous page
}
}
// For backward pagination
POST /core/posts
{
"operation": "read",
"options": {
"sort": [{"column": "created_at", "direction": "desc"}],
"limit": 50,
"cursor_backward": "12300" // ID of first record from current page
}
}
```
### Recursive CRUD Operations (🆕) ### Recursive CRUD Operations (🆕)
ResolveSpec now supports automatic handling of nested object graphs with intelligent foreign key resolution. This allows you to create, update, or delete entire object hierarchies in a single request. ResolveSpec now supports automatic handling of nested object graphs with intelligent foreign key resolution. This allows you to create, update, or delete entire object hierarchies in a single request.
#### Creating Nested Objects #### Creating Nested Objects
```json ```JSON
POST /core/users POST /core/users
{ {
"operation": "create", "operation": "create",
@@ -436,7 +580,7 @@ POST /core/users
Control individual operations for each nested record using the special `_request` field: Control individual operations for each nested record using the special `_request` field:
```json ```JSON
POST /core/users/123 POST /core/users/123
{ {
"operation": "update", "operation": "update",
@@ -462,11 +606,12 @@ POST /core/users/123
} }
``` ```
**Supported `_request` values**: **Supported** **`_request`** **values**:
- `insert` - Create a new related record
- `update` - Update an existing related record * `insert` - Create a new related record
- `delete` - Delete a related record * `update` - Update an existing related record
- `upsert` - Create if doesn't exist, update if exists * `delete` - Delete a related record
* `upsert` - Create if doesn't exist, update if exists
#### How It Works #### How It Works
@@ -478,14 +623,14 @@ POST /core/users/123
#### Benefits #### Benefits
- Reduce API round trips for complex object graphs * Reduce API round trips for complex object graphs
- Maintain referential integrity automatically * Maintain referential integrity automatically
- Simplify client-side code * Simplify client-side code
- Atomic operations with automatic rollback on errors * Atomic operations with automatic rollback on errors
## Installation ## Installation
```bash ```Shell
go get github.com/bitechdev/ResolveSpec go get github.com/bitechdev/ResolveSpec
``` ```
@@ -495,7 +640,7 @@ go get github.com/bitechdev/ResolveSpec
ResolveSpec uses JSON request bodies to specify query options: ResolveSpec uses JSON request bodies to specify query options:
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec" import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create handler // Create handler
@@ -522,7 +667,7 @@ resolvespec.SetupRoutes(router, handler)
RestHeadSpec uses HTTP headers for query options instead of request body: RestHeadSpec uses HTTP headers for query options instead of request body:
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/restheadspec" import "github.com/bitechdev/ResolveSpec/pkg/restheadspec"
// Create handler with GORM // Create handler with GORM
@@ -551,7 +696,7 @@ See [RestHeadSpec: Header-Based API](#restheadspec-header-based-api-1) for compl
Your existing code continues to work without any changes: Your existing code continues to work without any changes:
```go ```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec" import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// This still works exactly as before // This still works exactly as before
@@ -569,7 +714,7 @@ ResolveSpec v2.0 introduces a new database and router abstraction layer while ma
To update your imports: To update your imports:
```bash ```Shell
# Update go.mod # Update go.mod
go mod edit -replace github.com/Warky-Devs/ResolveSpec=github.com/bitechdev/ResolveSpec@latest go mod edit -replace github.com/Warky-Devs/ResolveSpec=github.com/bitechdev/ResolveSpec@latest
go mod tidy go mod tidy
@@ -581,7 +726,7 @@ go mod tidy
Alternatively, use find and replace in your project: Alternatively, use find and replace in your project:
```bash ```Shell
find . -type f -name "*.go" -exec sed -i 's|github.com/Warky-Devs/ResolveSpec|github.com/bitechdev/ResolveSpec|g' {} + find . -type f -name "*.go" -exec sed -i 's|github.com/Warky-Devs/ResolveSpec|github.com/bitechdev/ResolveSpec|g' {} +
go mod tidy go mod tidy
``` ```
@@ -596,7 +741,7 @@ go mod tidy
### Detailed Migration Guide ### Detailed Migration Guide
For detailed migration instructions, examples, and best practices, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md). For detailed migration instructions, examples, and best practices, see [MIGRATION\_GUIDE.md](MIGRATION_GUIDE.md).
## Architecture ## Architecture
@@ -638,22 +783,23 @@ Your Application Code
### Supported Database Layers ### Supported Database Layers
- **GORM** (default, fully supported) * **GORM** (default, fully supported)
- **Bun** (ready to use, included in dependencies) * **Bun** (ready to use, included in dependencies)
- **Custom ORMs** (implement the `Database` interface) * **Custom ORMs** (implement the `Database` interface)
### Supported Routers ### Supported Routers
- **Gorilla Mux** (built-in support with `SetupRoutes()`) * **Gorilla Mux** (built-in support with `SetupRoutes()`)
- **BunRouter** (built-in support with `SetupBunRouterWithResolveSpec()`) * **BunRouter** (built-in support with `SetupBunRouterWithResolveSpec()`)
- **Gin** (manual integration, see examples above) * **Gin** (manual integration, see examples above)
- **Echo** (manual integration, see examples above) * **Echo** (manual integration, see examples above)
- **Custom Routers** (implement request/response adapters) * **Custom Routers** (implement request/response adapters)
### Option 2: New Database-Agnostic API ### Option 2: New Database-Agnostic API
#### With GORM (Recommended Migration Path) #### With GORM (Recommended Migration Path)
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec" import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
// Create database adapter // Create database adapter
@@ -669,7 +815,8 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
``` ```
#### With Bun ORM #### With Bun ORM
```go
```Go
import "github.com/bitechdev/ResolveSpec/pkg/resolvespec" import "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
import "github.com/uptrace/bun" import "github.com/uptrace/bun"
@@ -684,22 +831,25 @@ handler := resolvespec.NewHandler(dbAdapter, registry)
### Router Integration ### Router Integration
#### Gorilla Mux (Built-in Support) #### Gorilla Mux (Built-in Support)
```go
```Go
import "github.com/gorilla/mux" import "github.com/gorilla/mux"
// Backward compatible way // Register models first
router := mux.NewRouter() handler.Registry.RegisterModel("public.users", &User{})
resolvespec.SetupRoutes(router, handler) handler.Registry.RegisterModel("public.posts", &Post{})
// Or manually: // Setup routes - creates explicit routes for each model
router.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { router := mux.NewRouter()
vars := mux.Vars(r) resolvespec.SetupMuxRoutes(router, handler, nil)
handler.Handle(w, r, vars)
}).Methods("POST") // Routes created: /public/users, /public/posts, etc.
// Each route includes GET, POST, and OPTIONS methods with CORS support
``` ```
#### Gin (Custom Integration) #### Gin (Custom Integration)
```go
```Go
import "github.com/gin-gonic/gin" import "github.com/gin-gonic/gin"
func setupGin(handler *resolvespec.Handler) *gin.Engine { func setupGin(handler *resolvespec.Handler) *gin.Engine {
@@ -722,7 +872,8 @@ func setupGin(handler *resolvespec.Handler) *gin.Engine {
``` ```
#### Echo (Custom Integration) #### Echo (Custom Integration)
```go
```Go
import "github.com/labstack/echo/v4" import "github.com/labstack/echo/v4"
func setupEcho(handler *resolvespec.Handler) *echo.Echo { func setupEcho(handler *resolvespec.Handler) *echo.Echo {
@@ -745,7 +896,8 @@ func setupEcho(handler *resolvespec.Handler) *echo.Echo {
``` ```
#### BunRouter (Built-in Support) #### BunRouter (Built-in Support)
```go
```Go
import "github.com/uptrace/bunrouter" import "github.com/uptrace/bunrouter"
// Simple setup with built-in function // Simple setup with built-in function
@@ -790,7 +942,8 @@ func setupFullUptrace(bunDB *bun.DB) *bunrouter.Router {
## Configuration ## Configuration
### Model Registration ### Model Registration
```go
```Go
type User struct { type User struct {
ID uint `json:"id" gorm:"primaryKey"` ID uint `json:"id" gorm:"primaryKey"`
Name string `json:"name"` Name string `json:"name"`
@@ -804,20 +957,24 @@ handler.RegisterModel("core", "users", &User{})
## Features in Detail ## Features in Detail
### Filtering ### Filtering
Supported operators: Supported operators:
- eq: Equal
- neq: Not Equal * eq: Equal
- gt: Greater Than * neq: Not Equal
- gte: Greater Than or Equal * gt: Greater Than
- lt: Less Than * gte: Greater Than or Equal
- lte: Less Than or Equal * lt: Less Than
- like: LIKE pattern matching * lte: Less Than or Equal
- ilike: Case-insensitive LIKE * like: LIKE pattern matching
- in: IN clause * ilike: Case-insensitive LIKE
* in: IN clause
### Sorting ### Sorting
Support for multiple sort criteria with direction: Support for multiple sort criteria with direction:
```json
```JSON
"sort": [ "sort": [
{ {
"column": "created_at", "column": "created_at",
@@ -831,8 +988,10 @@ Support for multiple sort criteria with direction:
``` ```
### Computed Columns ### Computed Columns
Define virtual columns using SQL expressions: Define virtual columns using SQL expressions:
```json
```JSON
"computedColumns": [ "computedColumns": [
{ {
"name": "full_name", "name": "full_name",
@@ -845,7 +1004,7 @@ Define virtual columns using SQL expressions:
### With New Architecture (Mockable) ### With New Architecture (Mockable)
```go ```Go
import "github.com/stretchr/testify/mock" import "github.com/stretchr/testify/mock"
// Create mock database // Create mock database
@@ -880,14 +1039,14 @@ ResolveSpec uses GitHub Actions for automated testing and quality checks. The CI
The project includes automated workflows that: The project includes automated workflows that:
- **Test**: Run all tests with race detection and code coverage * **Test**: Run all tests with race detection and code coverage
- **Lint**: Check code quality with golangci-lint * **Lint**: Check code quality with golangci-lint
- **Build**: Verify the project builds successfully * **Build**: Verify the project builds successfully
- **Multi-version**: Test against multiple Go versions (1.23.x, 1.24.x) * **Multi-version**: Test against multiple Go versions (1.23.x, 1.24.x)
### Running Tests Locally ### Running Tests Locally
```bash ```Shell
# Run all tests # Run all tests
go test -v ./... go test -v ./...
@@ -905,13 +1064,13 @@ golangci-lint run
The project includes comprehensive test coverage: The project includes comprehensive test coverage:
- **Unit Tests**: Individual component testing * **Unit Tests**: Individual component testing
- **Integration Tests**: End-to-end API testing * **Integration Tests**: End-to-end API testing
- **CRUD Tests**: Standalone tests for both ResolveSpec and RestHeadSpec APIs * **CRUD Tests**: Standalone tests for both ResolveSpec and RestHeadSpec APIs
To run only the CRUD standalone tests: To run only the CRUD standalone tests:
```bash ```Shell
go test -v ./tests -run TestCRUDStandalone go test -v ./tests -run TestCRUDStandalone
``` ```
@@ -923,18 +1082,18 @@ Check the [Actions tab](../../actions) on GitHub to see the status of recent CI
Add this badge to display CI status in your fork: Add this badge to display CI status in your fork:
```markdown ```Markdown
![Tests](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg) ![Tests](https://github.com/bitechdev/ResolveSpec/workflows/Tests/badge.svg)
``` ```
## Security Considerations ## Security Considerations
- Implement proper authentication and authorization * Implement proper authentication and authorization
- Validate all input parameters * Validate all input parameters
- Use prepared statements (handled by GORM/Bun/your ORM) * Use prepared statements (handled by GORM/Bun/your ORM)
- Implement rate limiting * Implement rate limiting
- Control access at schema/entity level * Control access at schema/entity level
- **New**: Database abstraction layer provides additional security through interface boundaries * **New**: Database abstraction layer provides additional security through interface boundaries
## Contributing ## Contributing
@@ -946,73 +1105,114 @@ Add this badge to display CI status in your fork:
## License ## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
## What's New ## What's New
### v2.1 (Latest) ### v3.0 (Latest - December 2025)
**Explicit Route Registration (🆕)**:
* **Breaking Change**: Routes are now created explicitly for each registered model
* **Better Control**: Customize routes per model with more flexibility
* **Registration Order**: Models must be registered BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* **Benefits**: More flexible routing, easier to add custom routes per model, better performance
**OPTIONS Method & CORS Support (🆕)**:
* **OPTIONS Endpoint**: Full OPTIONS method support for CORS preflight requests
* **Metadata Response**: OPTIONS returns model metadata (same as GET /metadata)
* **CORS Headers**: Comprehensive CORS headers on all responses
* **Header Support**: All HeadSpec custom headers (`X-Select-Fields`, `X-FieldFilter-*`, etc.) allowed
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
* **Configurable**: Customize CORS settings via `common.CORSConfig`
**Migration Notes**:
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
* This is a **breaking change** but provides better control and flexibility
### v2.1
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
* **Cursor-Based Pagination**: Efficient cursor pagination now available in ResolveSpec (body-based API)
* **Consistent with RestHeadSpec**: Both APIs now support cursor pagination for feature parity
* **Multi-Column Sort Support**: Works seamlessly with complex sorting requirements
* **Better Performance**: Improved performance for large datasets compared to offset pagination
* **SQL Safety**: Proper SQL sanitization for cursor values
**Recursive CRUD Handler (🆕 Nov 11, 2025)**: **Recursive CRUD Handler (🆕 Nov 11, 2025)**:
- **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
- **Foreign Key Resolution**: Automatic propagation of parent IDs to child records * **Nested Object Graphs**: Automatically handle complex object hierarchies with parent-child relationships
- **Per-Record Operations**: Control create/update/delete operations per record via `_request` field * **Foreign Key Resolution**: Automatic propagation of parent IDs to child records
- **Transaction Safety**: All nested operations execute atomically within database transactions * **Per-Record Operations**: Control create/update/delete operations per record via `_request` field
- **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships * **Transaction Safety**: All nested operations execute atomically within database transactions
- **Deep Nesting Support**: Handle relationships at any depth level * **Relationship Detection**: Automatic detection of belongsTo, hasMany, hasOne, and many2many relationships
- **Mixed Operations**: Combine insert, update, and delete operations in a single request * **Deep Nesting Support**: Handle relationships at any depth level
* **Mixed Operations**: Combine insert, update, and delete operations in a single request
**Primary Key Improvements (Nov 11, 2025)**: **Primary Key Improvements (Nov 11, 2025)**:
- **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
- **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations * **GetPrimaryKeyName**: Enhanced primary key detection for better preload and ID field handling
- **Computed Column Support**: Fixed computed columns functionality across handlers * **Better GORM/Bun Support**: Improved compatibility with both ORMs for primary key operations
* **Computed Column Support**: Fixed computed columns functionality across handlers
**Database Adapter Enhancements (Nov 11, 2025)**: **Database Adapter Enhancements (Nov 11, 2025)**:
- **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
- **Model Method Support**: Enhanced query building with proper model registration * **Bun ORM Relations**: Using Scan model method for better has-many and many-to-many relationship handling
- **Improved Type Safety**: Better handling of relationship queries with type-aware scanning * **Model Method Support**: Enhanced query building with proper model registration
* **Improved Type Safety**: Better handling of relationship queries with type-aware scanning
**RestHeadSpec - Header-Based REST API**: **RestHeadSpec - Header-Based REST API**:
- **Header-Based Querying**: All query options via HTTP headers instead of request body
- **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations * **Header-Based Querying**: All query options via HTTP headers instead of request body
- **Cursor Pagination**: Efficient cursor-based pagination with complex sorting * **Lifecycle Hooks**: Before/after hooks for create, read, update, delete operations
- **Advanced Filtering**: Field filters, search operators, AND/OR logic * **Cursor Pagination**: Efficient cursor-based pagination with complex sorting
- **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses * **Advanced Filtering**: Field filters, search operators, AND/OR logic
- **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header) * **Multiple Response Formats**: Simple, detailed, and Syncfusion-compatible responses
- **Base64 Support**: Base64-encoded header values for complex queries * **Single Record as Object**: Automatically return single-element arrays as objects (default, toggleable via header)
- **Type-Aware Filtering**: Automatic type detection and conversion for filters * **Base64 Support**: Base64-encoded header values for complex queries
* **Type-Aware Filtering**: Automatic type detection and conversion for filters
**Core Improvements**: **Core Improvements**:
- Better model registry with schema.table format support
- Enhanced validation and error handling * Better model registry with schema.table format support
- Improved reflection safety * Enhanced validation and error handling
- Fixed COUNT query issues with table aliasing * Improved reflection safety
- Better pointer handling throughout the codebase * Fixed COUNT query issues with table aliasing
- **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec * Better pointer handling throughout the codebase
* **Comprehensive Test Coverage**: Added standalone CRUD tests for both ResolveSpec and RestHeadSpec
### v2.0 ### v2.0
**Breaking Changes**: **Breaking Changes**:
- **None!** Full backward compatibility maintained
* **None!** Full backward compatibility maintained
**New Features**: **New Features**:
- **Database Abstraction**: Support for GORM, Bun, and custom ORMs
- **Router Flexibility**: Works with any HTTP router through adapters * **Database Abstraction**: Support for GORM, Bun, and custom ORMs
- **BunRouter Integration**: Built-in support for uptrace/bunrouter * **Router Flexibility**: Works with any HTTP router through adapters
- **Better Architecture**: Clean separation of concerns with interfaces * **BunRouter Integration**: Built-in support for uptrace/bunrouter
- **Enhanced Testing**: Mockable interfaces for comprehensive testing * **Better Architecture**: Clean separation of concerns with interfaces
- **Migration Guide**: Step-by-step migration instructions * **Enhanced Testing**: Mockable interfaces for comprehensive testing
* **Migration Guide**: Step-by-step migration instructions
**Performance Improvements**: **Performance Improvements**:
- More efficient query building through interface design
- Reduced coupling between components * More efficient query building through interface design
- Better memory management with interface boundaries * Reduced coupling between components
* Better memory management with interface boundaries
## Acknowledgments ## Acknowledgments
- Inspired by REST, OData, and GraphQL's flexibility * Inspired by REST, OData, and GraphQL's flexibility
- **Header-based approach**: Inspired by REST best practices and clean API design * **Header-based approach**: Inspired by REST best practices and clean API design
- **Database Support**: [GORM](https://gorm.io) and [Bun](https://bun.uptrace.dev/) * **Database Support**: [GORM](https://gorm.io) and [Bun](https://bun.uptrace.dev/)
- **Router Support**: Gorilla Mux (built-in), BunRouter, Gin, Echo, and others through adapters * **Router Support**: Gorilla Mux (built-in), BunRouter, Gin, Echo, and others through adapters
- Slogan generated using DALL-E * Slogan generated using DALL-E
- AI used for documentation checking and correction * AI used for documentation checking and correction
- Community feedback and contributions that made v2.0 and v2.1 possible * Community feedback and contributions that made v2.0 and v2.1 possible

View File

@@ -6,8 +6,10 @@ import (
"os" "os"
"time" "time"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/server"
"github.com/bitechdev/ResolveSpec/pkg/testmodels" "github.com/bitechdev/ResolveSpec/pkg/testmodels"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec" "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
@@ -19,12 +21,27 @@ import (
) )
func main() { func main() {
// Initialize logger // Load configuration
logger.Init(true) cfgMgr := config.NewManager()
if err := cfgMgr.Load(); err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
cfg, err := cfgMgr.GetConfig()
if err != nil {
log.Fatalf("Failed to get configuration: %v", err)
}
// Initialize logger with configuration
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
logger.Info("ResolveSpec test server starting") logger.Info("ResolveSpec test server starting")
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
// Initialize database // Initialize database
db, err := initDB() db, err := initDB(cfg)
if err != nil { if err != nil {
logger.Error("Failed to initialize database: %+v", err) logger.Error("Failed to initialize database: %+v", err)
os.Exit(1) os.Exit(1)
@@ -47,32 +64,54 @@ func main() {
handler.RegisterModel("public", modelNames[i], model) handler.RegisterModel("public", modelNames[i], model)
} }
// Setup routes using new SetupMuxRoutes function // Setup routes using new SetupMuxRoutes function (without authentication)
resolvespec.SetupMuxRoutes(r, handler) resolvespec.SetupMuxRoutes(r, handler, nil)
// Start server // Create graceful server with configuration
logger.Info("Starting server on :8080") srv := server.NewGracefulServer(server.Config{
if err := http.ListenAndServe(":8080", r); err != nil { Addr: cfg.Server.Addr,
Handler: r,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
DrainTimeout: cfg.Server.DrainTimeout,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
})
// Start server with graceful shutdown
logger.Info("Starting server on %s", cfg.Server.Addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Server failed to start: %v", err) logger.Error("Server failed to start: %v", err)
os.Exit(1) os.Exit(1)
} }
} }
func initDB() (*gorm.DB, error) { func initDB(cfg *config.Config) (*gorm.DB, error) {
// Configure GORM logger based on config
logLevel := gormlog.Info
if !cfg.Logger.Dev {
logLevel = gormlog.Warn
}
newLogger := gormlog.New( newLogger := gormlog.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
gormlog.Config{ gormlog.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: gormlog.Info, // Log level LogLevel: logLevel, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true, // Disable color Colorful: cfg.Logger.Dev,
}, },
) )
// Use database URL from config if available, otherwise use default SQLite
dbURL := cfg.Database.URL
if dbURL == "" {
dbURL = "test.db"
}
// Create SQLite database // Create SQLite database
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{Logger: newLogger, FullSaveAssociations: false}) db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
if err != nil { if err != nil {
return nil, err return nil, err
} }

41
config.yaml Normal file
View File

@@ -0,0 +1,41 @@
# ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
logger:
dev: true # Enable development mode for readable logs
path: "" # Empty means log to stdout
cache:
provider: "memory"
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
tracing:
enabled: false
database:
url: "" # Empty means use default SQLite (test.db)

57
config.yaml.example Normal file
View File

@@ -0,0 +1,57 @@
# ResolveSpec Configuration Example
# This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
tracing:
enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
logger:
dev: false
path: "" # Empty for stdout, or specify file path
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB in bytes
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
database:
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"

27
docker-compose.yml Normal file
View File

@@ -0,0 +1,27 @@
services:
postgres-test:
image: postgres:15-alpine
container_name: resolvespec-postgres-test
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
- "5434:5432"
volumes:
- postgres-test-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
networks:
- resolvespec-test
volumes:
postgres-test-data:
driver: local
networks:
resolvespec-test:
driver: bridge

59
go.mod
View File

@@ -1,45 +1,94 @@
module github.com/bitechdev/ResolveSpec module github.com/bitechdev/ResolveSpec
go 1.23.0 go 1.24.0
toolchain go1.24.6 toolchain go1.24.6
require ( require (
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/glebarez/sqlite v1.11.0 github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/stretchr/testify v1.8.1 github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0 github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5 github.com/tidwall/sjson v1.2.5
github.com/uptrace/bun v1.2.15 github.com/uptrace/bun v1.2.15
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
github.com/uptrace/bun/driver/sqliteshim v1.2.15 github.com/uptrace/bun/driver/sqliteshim v1.2.15
github.com/uptrace/bunrouter v1.0.23 github.com/uptrace/bunrouter v1.0.23
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.25.12 gorm.io/gorm v1.25.12
) )
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/getsentry/sentry-go v0.40.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/spf13/viper v1.21.0 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect github.com/tidwall/pretty v1.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.10.0 // indirect go.uber.org/multierr v1.10.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
golang.org/x/sys v0.34.0 // indirect golang.org/x/net v0.43.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.66.3 // indirect modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect

142
go.sum
View File

@@ -1,45 +1,119 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -63,31 +137,71 @@ github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IU
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc= golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc= golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM= modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=

340
pkg/cache/README.md vendored Normal file
View File

@@ -0,0 +1,340 @@
# Cache Package
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
## Features
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
- **Pluggable Architecture**: Easy to add custom cache providers
- **Type-Safe API**: Automatic JSON serialization/deserialization
- **TTL Support**: Configurable time-to-live for cache entries
- **Context-Aware**: All operations support Go contexts
- **Statistics**: Built-in cache statistics and monitoring
- **Pattern Deletion**: Delete keys by pattern (Redis)
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/cache
```
For Redis support:
```bash
go get github.com/redis/go-redis/v9
```
For Memcache support:
```bash
go get github.com/bradfitz/gomemcache/memcache
```
## Quick Start
### In-Memory Cache
```go
package main
import (
"context"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
func main() {
// Initialize with in-memory provider
cache.UseMemory(&cache.Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
defer cache.Close()
ctx := context.Background()
c := cache.GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John"}
c.Set(ctx, "user:1", user, 10*time.Minute)
// Retrieve a value
var retrieved User
c.Get(ctx, "user:1", &retrieved)
}
```
### Redis Cache
```go
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "",
DB: 0,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
### Memcache
```go
cache.UseMemcache(&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
## API Reference
### Core Methods
#### Set
```go
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
```
Stores a value in the cache with automatic JSON serialization.
#### Get
```go
Get(ctx context.Context, key string, dest interface{}) error
```
Retrieves and deserializes a value from the cache.
#### SetBytes / GetBytes
```go
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
GetBytes(ctx context.Context, key string) ([]byte, error)
```
Store and retrieve raw bytes without serialization.
#### Delete
```go
Delete(ctx context.Context, key string) error
```
Removes a key from the cache.
#### DeleteByPattern
```go
DeleteByPattern(ctx context.Context, pattern string) error
```
Removes all keys matching a pattern (Redis only).
#### Clear
```go
Clear(ctx context.Context) error
```
Removes all items from the cache.
#### Exists
```go
Exists(ctx context.Context, key string) bool
```
Checks if a key exists in the cache.
#### GetOrSet
```go
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
loader func() (interface{}, error)) error
```
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
#### Stats
```go
Stats(ctx context.Context) (*CacheStats, error)
```
Returns cache statistics including hits, misses, and key counts.
## Provider Configuration
### In-Memory Options
```go
&cache.Options{
DefaultTTL: 5 * time.Minute, // Default expiration time
MaxSize: 10000, // Maximum number of items
EvictionPolicy: "LRU", // Eviction strategy (future)
}
```
### Redis Configuration
```go
&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Optional authentication
DB: 0, // Database number
PoolSize: 10, // Connection pool size
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
### Memcache Configuration
```go
&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
MaxIdleConns: 2,
Timeout: 1 * time.Second,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
## Advanced Usage
### Custom Provider
```go
// Create a custom provider instance
memProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
cache.Initialize(memProvider)
```
### Lazy Loading Pattern
```go
var data ExpensiveData
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if key is not in cache
return computeExpensiveData(), nil
})
```
### Query API Cache
The package includes specialized functions for caching query results:
```go
// Cache a query result
api := "GetUsers"
query := "SELECT * FROM users WHERE active = true"
tablenames := "users"
total := int64(150)
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
// Retrieve cached query
hash := cache.HashQueryAPICache(api, query)
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
```
## Provider Comparison
| Feature | In-Memory | Redis | Memcache |
|---------|-----------|-------|----------|
| Persistence | No | Yes | No |
| Distributed | No | Yes | Yes |
| Pattern Delete | No | Yes | No |
| Statistics | Full | Full | Limited |
| Atomic Operations | Yes | Yes | Yes |
| Max Item Size | Memory | 512MB | 1MB |
## Best Practices
1. **Use contexts**: Always pass context for cancellation and timeout control
2. **Set appropriate TTLs**: Balance between freshness and performance
3. **Handle errors**: Cache misses and errors should be handled gracefully
4. **Monitor statistics**: Use Stats() to monitor cache performance
5. **Clean up**: Always call Close() when shutting down
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
## Example: Complete Application
```go
package main
import (
"context"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
type UserService struct {
cache *cache.Cache
}
func NewUserService() *UserService {
// Initialize with Redis in production, memory for testing
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Options: &cache.Options{
DefaultTTL: 10 * time.Minute,
},
})
return &UserService{
cache: cache.GetDefaultCache(),
}
}
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
var user User
cacheKey := fmt.Sprintf("user:%d", userID)
// Try to get from cache first
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
// Load from database if not in cache
return s.loadUserFromDB(userID)
})
if err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
cacheKey := fmt.Sprintf("user:%d", userID)
return s.cache.Delete(ctx, cacheKey)
}
func main() {
service := NewUserService()
defer cache.Close()
ctx := context.Background()
user, err := service.GetUser(ctx, 123)
if err != nil {
log.Fatal(err)
}
log.Printf("User: %+v", user)
}
```
## Performance Considerations
- **In-Memory**: Fastest but limited by RAM and not distributed
- **Redis**: Great for distributed systems, persistent, but network overhead
- **Memcache**: Good for distributed caching, simpler than Redis but less features
Choose based on your needs:
- Single instance? Use in-memory
- Need persistence or advanced features? Use Redis
- Simple distributed cache? Use Memcache
## License
See repository license.

76
pkg/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"fmt"
"time"
)
var (
defaultCache *Cache
)
// Initialize initializes the cache with a provider.
// If not called, the package will use an in-memory provider by default.
func Initialize(provider Provider) {
defaultCache = NewCache(provider)
}
// UseMemory configures the cache to use in-memory storage.
func UseMemory(opts *Options) error {
provider := NewMemoryProvider(opts)
defaultCache = NewCache(provider)
return nil
}
// UseRedis configures the cache to use Redis storage.
func UseRedis(config *RedisConfig) error {
provider, err := NewRedisProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Redis provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// UseMemcache configures the cache to use Memcache storage.
func UseMemcache(config *MemcacheConfig) error {
provider, err := NewMemcacheProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// GetDefaultCache returns the default cache instance.
// Initializes with in-memory provider if not already initialized.
func GetDefaultCache() *Cache {
if defaultCache == nil {
_ = UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
}
return defaultCache
}
// SetDefaultCache sets a custom cache instance as the default cache.
// This is useful for testing or when you want to use a pre-configured cache instance.
func SetDefaultCache(cache *Cache) {
defaultCache = cache
}
// GetStats returns cache statistics.
func GetStats(ctx context.Context) (*CacheStats, error) {
cache := GetDefaultCache()
return cache.Stats(ctx)
}
// Close closes the cache and releases resources.
func Close() error {
if defaultCache != nil {
return defaultCache.Close()
}
return nil
}

147
pkg/cache/cache_manager.go vendored Normal file
View File

@@ -0,0 +1,147 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
)
// Cache is the main cache manager that wraps a Provider.
type Cache struct {
provider Provider
}
// NewCache creates a new cache manager with the specified provider.
func NewCache(provider Provider) *Cache {
return &Cache{
provider: provider,
}
}
// Get retrieves and deserializes a value from the cache.
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
data, exists := c.provider.Get(ctx, key)
if !exists {
return fmt.Errorf("key not found: %s", key)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize: %w", err)
}
return nil
}
// GetBytes retrieves raw bytes from the cache.
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
data, exists := c.provider.Get(ctx, key)
if !exists {
return nil, fmt.Errorf("key not found: %s", key)
}
return data, nil
}
// Set serializes and stores a value in the cache with the specified TTL.
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.Set(ctx, key, data, ttl)
}
// SetBytes stores raw bytes in the cache with the specified TTL.
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return c.provider.Set(ctx, key, value, ttl)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)
}
// Clear removes all items from the cache.
func (c *Cache) Clear(ctx context.Context) error {
return c.provider.Clear(ctx)
}
// Exists checks if a key exists in the cache.
func (c *Cache) Exists(ctx context.Context, key string) bool {
return c.provider.Exists(ctx, key)
}
// Stats returns statistics about the cache.
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
return c.provider.Stats(ctx)
}
// Close closes the cache and releases any resources.
func (c *Cache) Close() error {
return c.provider.Close()
}
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
// The loader function is called only if the key is not found in cache.
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
// Try to get from cache first
err := c.Get(ctx, key, dest)
if err == nil {
return nil
}
// Load the value
value, err := loader()
if err != nil {
return fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return fmt.Errorf("failed to cache value: %w", err)
}
// Populate dest with the loaded value
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize loaded value: %w", err)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize loaded value: %w", err)
}
return nil
}
// Remember is a convenience function that caches the result of a function call.
// It's similar to GetOrSet but returns the value directly.
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
// Try to get from cache first as bytes
data, err := c.GetBytes(ctx, key)
if err == nil {
var result interface{}
if err := json.Unmarshal(data, &result); err == nil {
return result, nil
}
}
// Load the value
value, err := loader()
if err != nil {
return nil, fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return nil, fmt.Errorf("failed to cache value: %w", err)
}
return value, nil
}

69
pkg/cache/cache_test.go vendored Normal file
View File

@@ -0,0 +1,69 @@
package cache
import (
"context"
"testing"
"time"
)
func TestSetDefaultCache(t *testing.T) {
// Create a custom cache instance
provider := NewMemoryProvider(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 50,
})
customCache := NewCache(provider)
// Set it as the default
SetDefaultCache(customCache)
// Verify it's now the default
retrievedCache := GetDefaultCache()
if retrievedCache != customCache {
t.Error("SetDefaultCache did not set the cache correctly")
}
// Test that we can use it
ctx := context.Background()
testKey := "test_key"
testValue := "test_value"
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
if err != nil {
t.Fatalf("Failed to set value: %v", err)
}
var result string
err = retrievedCache.Get(ctx, testKey, &result)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
if result != testValue {
t.Errorf("Expected %s, got %s", testValue, result)
}
// Clean up - reset to default
SetDefaultCache(nil)
}
func TestGetDefaultCacheInitialization(t *testing.T) {
// Reset to nil first
SetDefaultCache(nil)
// GetDefaultCache should auto-initialize
cache := GetDefaultCache()
if cache == nil {
t.Error("GetDefaultCache should auto-initialize, got nil")
}
// Should be usable
ctx := context.Background()
err := cache.Set(ctx, "test", "value", time.Minute)
if err != nil {
t.Errorf("Failed to use auto-initialized cache: %v", err)
}
// Clean up
SetDefaultCache(nil)
}

266
pkg/cache/example_usage.go vendored Normal file
View File

@@ -0,0 +1,266 @@
package cache
import (
"context"
"fmt"
"log"
"time"
)
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
func ExampleInMemoryCache() {
// Initialize with in-memory provider
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John Doe"}
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved User
err = cache.Get(ctx, "user:1", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved user: %+v\n", retrieved)
// Check if key exists
exists := cache.Exists(ctx, "user:1")
fmt.Printf("Key exists: %v\n", exists)
// Delete a key
err = cache.Delete(ctx, "user:1")
if err != nil {
_ = Close()
log.Fatal(err)
}
// Get statistics
stats, err := cache.Stats(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cache stats: %+v\n", stats)
_ = Close()
}
// ExampleRedisCache demonstrates using the Redis cache provider.
func ExampleRedisCache() {
// Initialize with Redis provider
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Set if Redis requires authentication
DB: 0,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store raw bytes
data := []byte("Hello, Redis!")
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve raw bytes
retrieved, err := cache.GetBytes(ctx, "greeting")
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved data: %s\n", string(retrieved))
// Clear all cache
err = cache.Clear(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
_ = Close()
}
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
func ExampleMemcacheCache() {
// Initialize with Memcache provider
err := UseMemcache(&MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type Product struct {
ID int
Name string
Price float64
}
product := Product{ID: 100, Name: "Widget", Price: 29.99}
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved Product
err = cache.Get(ctx, "product:100", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved product: %+v\n", retrieved)
_ = Close()
}
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
func ExampleGetOrSet() {
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
type ExpensiveData struct {
Result string
}
var data ExpensiveData
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if the key is not in cache
fmt.Println("Computing expensive result...")
time.Sleep(1 * time.Second)
return ExpensiveData{Result: "computed value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Data: %+v\n", data)
// Second call will use cached value
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
fmt.Println("This won't be called!")
return ExpensiveData{Result: "new value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cached data: %+v\n", data)
_ = Close()
}
// ExampleCustomProvider demonstrates using a custom provider.
func ExampleCustomProvider() {
// Create a custom provider
memProvider := NewMemoryProvider(&Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
Initialize(memProvider)
ctx := context.Background()
cache := GetDefaultCache()
// Use the cache
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Clean expired items (memory provider specific)
if mp, ok := cache.provider.(*MemoryProvider); ok {
count := mp.CleanExpired(ctx)
fmt.Printf("Cleaned %d expired items\n", count)
}
_ = Close()
}
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
func ExampleDeleteByPattern() {
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
// Store multiple keys with a pattern
_ = cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
// Delete all keys matching pattern (Redis glob pattern)
err = cache.DeleteByPattern(ctx, "user:*:profile")
if err != nil {
_ = Close()
log.Print(err)
return
}
fmt.Println("Deleted all user profile keys")
_ = Close()
}

57
pkg/cache/provider.go vendored Normal file
View File

@@ -0,0 +1,57 @@
package cache
import (
"context"
"time"
)
// Provider defines the interface that all cache providers must implement.
type Provider interface {
// Get retrieves a value from the cache by key.
// Returns nil, false if key doesn't exist or is expired.
Get(ctx context.Context, key string) ([]byte, bool)
// Set stores a value in the cache with the specified TTL.
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error
// Clear removes all items from the cache.
Clear(ctx context.Context) error
// Exists checks if a key exists in the cache.
Exists(ctx context.Context, key string) bool
// Close closes the provider and releases any resources.
Close() error
// Stats returns statistics about the cache provider.
Stats(ctx context.Context) (*CacheStats, error)
}
// CacheStats contains cache statistics.
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
ProviderType string `json:"provider_type"`
ProviderStats map[string]any `json:"provider_stats,omitempty"`
}
// Options contains configuration options for cache providers.
type Options struct {
// DefaultTTL is the default time-to-live for cache items.
DefaultTTL time.Duration
// MaxSize is the maximum number of items (for in-memory provider).
MaxSize int
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
EvictionPolicy string
}

144
pkg/cache/provider_memcache.go vendored Normal file
View File

@@ -0,0 +1,144 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// MemcacheProvider is a Memcache implementation of the Provider interface.
type MemcacheProvider struct {
client *memcache.Client
options *Options
}
// MemcacheConfig contains Memcache-specific configuration.
type MemcacheConfig struct {
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
Servers []string
// MaxIdleConns is the maximum number of idle connections (default: 2)
MaxIdleConns int
// Timeout for connection operations (default: 1 second)
Timeout time.Duration
// Options contains general cache options
Options *Options
}
// NewMemcacheProvider creates a new Memcache cache provider.
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
if config == nil {
config = &MemcacheConfig{
Servers: []string{"localhost:11211"},
}
}
if len(config.Servers) == 0 {
config.Servers = []string{"localhost:11211"}
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 2
}
if config.Timeout == 0 {
config.Timeout = 1 * time.Second
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := memcache.New(config.Servers...)
client.MaxIdleConns = config.MaxIdleConns
client.Timeout = config.Timeout
// Test connection
if err := client.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
}
return &MemcacheProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
item, err := m.client.Get(key)
if err == memcache.ErrCacheMiss {
return nil, false
}
if err != nil {
return nil, false
}
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
item := &memcache.Item{
Key: key,
Value: value,
Expiration: int32(ttl.Seconds()),
}
return m.client.Set(item)
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
}
return err
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
}
// Clear removes all items from the cache.
func (m *MemcacheProvider) Clear(ctx context.Context) error {
return m.client.FlushAll()
}
// Exists checks if a key exists in the cache.
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
_, err := m.client.Get(key)
return err == nil
}
// Close closes the provider and releases any resources.
func (m *MemcacheProvider) Close() error {
// Memcache client doesn't have a close method
return nil
}
// Stats returns statistics about the cache provider.
// Note: Memcache provider returns limited statistics.
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
stats := &CacheStats{
ProviderType: "memcache",
ProviderStats: map[string]any{
"note": "Memcache does not provide detailed statistics through the standard client",
},
}
return stats, nil
}

238
pkg/cache/provider_memory.go vendored Normal file
View File

@@ -0,0 +1,238 @@
package cache
import (
"context"
"fmt"
"regexp"
"sync"
"sync/atomic"
"time"
)
// memoryItem represents a cached item in memory.
type memoryItem struct {
Value []byte
Expiration time.Time
LastAccess time.Time
HitCount int64
}
// isExpired checks if the item has expired.
func (m *memoryItem) isExpired() bool {
if m.Expiration.IsZero() {
return false
}
return time.Now().After(m.Expiration)
}
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits atomic.Int64
misses atomic.Int64
}
// NewMemoryProvider creates a new in-memory cache provider.
func NewMemoryProvider(opts *Options) *MemoryProvider {
if opts == nil {
opts = &Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
}
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
}
}
// Get retrieves a value from the cache by key.
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
// First try with read lock for fast path
m.mu.RLock()
item, exists := m.items[key]
if !exists {
m.mu.RUnlock()
m.misses.Add(1)
return nil, false
}
if item.isExpired() {
m.mu.RUnlock()
// Upgrade to write lock to delete expired item
m.mu.Lock()
delete(m.items, key)
m.mu.Unlock()
m.misses.Add(1)
return nil, false
}
// Update stats and access time with write lock
value := item.Value
m.mu.RUnlock()
// Update access tracking with write lock
m.mu.Lock()
item.LastAccess = time.Now()
item.HitCount++
m.mu.Unlock()
m.hits.Add(1)
return value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()
defer m.mu.Unlock()
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
for key := range m.items {
if re.MatchString(key) {
delete(m.items, key)
}
}
return nil
}
// Clear removes all items from the cache.
func (m *MemoryProvider) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryItem)
m.hits.Store(0)
m.misses.Store(0)
return nil
}
// Exists checks if a key exists in the cache.
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
item, exists := m.items[key]
if !exists {
return false
}
return !item.isExpired()
}
// Close closes the provider and releases any resources.
func (m *MemoryProvider) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = nil
return nil
}
// Stats returns statistics about the cache provider.
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
// Clean expired items first
validKeys := 0
for _, item := range m.items {
if !item.isExpired() {
validKeys++
}
}
return &CacheStats{
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Keys: int64(validKeys),
ProviderType: "memory",
ProviderStats: map[string]any{
"capacity": m.options.MaxSize,
},
}, nil
}
// evictOne removes one item from the cache using LRU strategy.
func (m *MemoryProvider) evictOne() {
var oldestKey string
var oldestTime time.Time
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
return
}
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
oldestKey = key
oldestTime = item.LastAccess
}
}
if oldestKey != "" {
delete(m.items, oldestKey)
}
}
// CleanExpired removes all expired items from the cache.
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
m.mu.Lock()
defer m.mu.Unlock()
count := 0
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
count++
}
}
return count
}

185
pkg/cache/provider_redis.go vendored Normal file
View File

@@ -0,0 +1,185 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RedisProvider is a Redis implementation of the Provider interface.
type RedisProvider struct {
client *redis.Client
options *Options
}
// RedisConfig contains Redis-specific configuration.
type RedisConfig struct {
// Host is the Redis server host (default: localhost)
Host string
// Port is the Redis server port (default: 6379)
Port int
// Password for Redis authentication (optional)
Password string
// DB is the Redis database number (default: 0)
DB int
// PoolSize is the maximum number of connections (default: 10)
PoolSize int
// Options contains general cache options
Options *Options
}
// NewRedisProvider creates a new Redis cache provider.
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
if config == nil {
config = &RedisConfig{
Host: "localhost",
Port: 6379,
DB: 0,
}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = 6379
}
if config.PoolSize == 0 {
config.PoolSize = 10
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
Password: config.Password,
DB: config.DB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
return &RedisProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
val, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return val, true
}
// Set stores a value in the cache with the specified TTL.
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
return r.client.Set(ctx, key, value, ttl).Err()
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
}
// DeleteByPattern removes all keys matching the pattern.
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
pipe := r.client.Pipeline()
count := 0
for iter.Next(ctx) {
pipe.Del(ctx, iter.Val())
count++
// Execute pipeline in batches of 100
if count%100 == 0 {
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = r.client.Pipeline()
}
}
if err := iter.Err(); err != nil {
return err
}
// Execute remaining commands
if count%100 != 0 {
_, err := pipe.Exec(ctx)
return err
}
return nil
}
// Clear removes all items from the cache.
func (r *RedisProvider) Clear(ctx context.Context) error {
return r.client.FlushDB(ctx).Err()
}
// Exists checks if a key exists in the cache.
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
result, err := r.client.Exists(ctx, key).Result()
if err != nil {
return false
}
return result > 0
}
// Close closes the provider and releases any resources.
func (r *RedisProvider) Close() error {
return r.client.Close()
}
// Stats returns statistics about the cache provider.
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
if err != nil {
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
}
dbSize, err := r.client.DBSize(ctx).Result()
if err != nil {
return nil, fmt.Errorf("failed to get DB size: %w", err)
}
// Parse stats from INFO command
// This is a simplified version - you may want to parse more detailed stats
stats := &CacheStats{
Keys: dbSize,
ProviderType: "redis",
ProviderStats: map[string]any{
"info": info,
},
}
return stats, nil
}

127
pkg/cache/query_cache.go vendored Normal file
View File

@@ -0,0 +1,127 @@
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
if where, ok := expMap["where"].(string); ok {
expKey.Where = where
}
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
}

151
pkg/cache/query_cache_test.go vendored Normal file
View File

@@ -0,0 +1,151 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@@ -0,0 +1,218 @@
# Automatic Relation Loading Strategies
## Overview
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
Simply use `PreloadRelation()` and the system automatically:
- Detects relationship type from Bun/GORM tags
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
## How It Works
```go
// Just write this - the system handles the rest!
db.NewSelect().
Model(&links).
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
Scan(ctx, &links)
```
### Detection Logic
The system inspects your model's struct tags:
**Bun models:**
```go
type Link struct {
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
}
```
**GORM models:**
```go
type Link struct {
ProviderID int
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
}
```
**Type inference (fallback):**
- `[]Type` (slice) → has-many → Separate query
- `*Type` (pointer) → belongs-to → JOIN
- `Type` (struct) → belongs-to → JOIN
### What Gets Logged
Enable debug logging to see strategy selection:
```go
bunAdapter.EnableQueryDebug()
```
**Output:**
```
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
INFO: Using JOIN strategy for belongs-to relation 'Provider'
DEBUG: PreloadRelation 'Links' detected as: has-many
DEBUG: Using separate query for has-many relation 'Links'
```
## Relationship Types
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|---------|--------------|------------|----------|-----|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
## Manual Override
If you need to force a specific strategy, use `JoinRelation()`:
```go
// Force JOIN even for has-many (not recommended)
db.NewSelect().
Model(&providers).
JoinRelation("Links"). // Explicitly use JOIN
Scan(ctx, &providers)
```
## Examples
### Automatic Strategy Selection (Recommended)
```go
// Example 1: Loading parent provider for each link
// System detects belongs-to → uses JOIN automatically
db.NewSelect().
Model(&links).
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &links)
// Generated SQL: Single query with JOIN
// SELECT links.*, providers.*
// FROM links
// LEFT JOIN providers ON links.provider_id = providers.id
// WHERE providers.active = true
// Example 2: Loading child links for each provider
// System detects has-many → uses separate query automatically
db.NewSelect().
Model(&providers).
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &providers)
// Generated SQL: Two queries
// Query 1: SELECT * FROM providers
// Query 2: SELECT * FROM links
// WHERE provider_id IN (1, 2, 3, ...)
// AND active = true
```
### Mixed Relationships
```go
type Order struct {
ID int
CustomerID int
Customer *Customer `bun:"rel:belongs-to"` // JOIN
Items []Item `bun:"rel:has-many"` // Separate
Invoice *Invoice `bun:"rel:has-one"` // JOIN
}
// All three handled optimally!
db.NewSelect().
Model(&orders).
PreloadRelation("Customer"). // → JOIN (many-to-one)
PreloadRelation("Items"). // → Separate (one-to-many)
PreloadRelation("Invoice"). // → JOIN (one-to-one)
Scan(ctx, &orders)
```
## Performance Benefits
### Before (Manual Strategy Selection)
```go
// You had to remember which to use:
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
.PreloadRelation("Links") // Which is more efficient here?
```
### After (Automatic Selection)
```go
// Just use PreloadRelation everywhere:
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
.PreloadRelation("Links") // ✓ System uses separate query automatically
```
## Migration Guide
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
```go
// Before: Always used separate query
.PreloadRelation("Provider") // Inefficient: extra round trip
// After: Automatic optimization
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
```
## Implementation Details
### Supported Bun Tags
- `rel:has-many` → Separate query
- `rel:belongs-to` → JOIN
- `rel:has-one` → JOIN
- `rel:many-to-many` or `rel:m2m` → Separate query
### Supported GORM Patterns
- `many2many:` tag → Separate query
- `foreignKey:` tag → JOIN (belongs-to)
- `[]Type` slice without many2many → Separate query (has-many)
- `*Type` pointer with foreignKey → JOIN (belongs-to)
- `*Type` pointer without foreignKey → JOIN (has-one)
### Fallback Behavior
- `[]Type` (slice) → Separate query (safe default for collections)
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
- Unknown → Separate query (safest default)
## Debugging
To see strategy selection in action:
```go
// Enable debug logging
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
// Run your query
db.NewSelect().
Model(&records).
PreloadRelation("RelationName").
Scan(ctx, &records)
// Check logs for:
// - "PreloadRelation 'X' detected as: belongs-to"
// - "Using JOIN strategy for belongs-to relation 'X'"
// - Actual SQL queries executed
```
## Best Practices
1. **Use PreloadRelation() for everything** - Let the system optimize
2. **Define proper relationship tags** - Ensures correct detection
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
4. **Enable debug logging during development** - Verify optimal strategies are chosen
5. **Trust the system** - It's designed to choose correctly based on relationship type

View File

@@ -0,0 +1,81 @@
package database
import (
"testing"
)
func TestNormalizeTableAlias(t *testing.T) {
tests := []struct {
name string
query string
expectedAlias string
tableName string
want string
}{
{
name: "strips plausible alias from simple condition",
query: "APIL.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = 2576",
},
{
name: "keeps correct alias",
query: "apiproviderlink.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "apiproviderlink.rid_hub = 2576",
},
{
name: "strips plausible alias with multiple conditions",
query: "APIL.rid_hub = ? AND APIL.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND active = ?",
},
{
name: "handles mixed correct and plausible aliases",
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND apiproviderlink.active = ?",
},
{
name: "handles parentheses",
query: "(APIL.rid_hub = ?)",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "(rid_hub = ?)",
},
{
name: "no alias in query",
query: "rid_hub = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ?",
},
{
name: "keeps reference to different table (not in current table name)",
query: "APIL.rid_hub = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "APIL.rid_hub = ?",
},
{
name: "keeps reference with short prefix that might be ambiguous",
query: "AP.rid = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "AP.rid = ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
if got != tt.want {
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
}
})
}
}

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"strings" "strings"
"time"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@@ -14,6 +16,24 @@ import (
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// QueryDebugHook is a Bun query hook that logs all SQL queries including preloads
type QueryDebugHook struct{}
func (h *QueryDebugHook) BeforeQuery(ctx context.Context, event *bun.QueryEvent) context.Context {
return ctx
}
func (h *QueryDebugHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {
query := event.Query
duration := time.Since(event.StartTime)
if event.Err != nil {
logger.Error("SQL Query Failed [%s]: %s. Error: %v", duration, query, event.Err)
} else {
logger.Debug("SQL Query Success [%s]: %s", duration, query)
}
}
// BunAdapter adapts Bun to work with our Database interface // BunAdapter adapts Bun to work with our Database interface
// This demonstrates how the abstraction works with different ORMs // This demonstrates how the abstraction works with different ORMs
type BunAdapter struct { type BunAdapter struct {
@@ -25,6 +45,20 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
return &BunAdapter{db: db} return &BunAdapter{db: db}
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (b *BunAdapter) EnableQueryDebug() {
b.db.AddQueryHook(&QueryDebugHook{})
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
}
// DisableQueryDebug removes all query hooks
func (b *BunAdapter) DisableQueryDebug() {
// Create a new DB without hooks
// Note: Bun doesn't have a RemoveQueryHook, so we'd need to track hooks manually
logger.Info("To disable query debug, recreate the BunAdapter without adding the hook")
}
func (b *BunAdapter) NewSelect() common.SelectQuery { func (b *BunAdapter) NewSelect() common.SelectQuery {
return &BunSelectQuery{ return &BunSelectQuery{
query: b.db.NewSelect(), query: b.db.NewSelect(),
@@ -46,8 +80,8 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.Exec", r)
} }
}() }()
result, err := b.db.ExecContext(ctx, query, args...) result, err := b.db.ExecContext(ctx, query, args...)
@@ -56,8 +90,8 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.Query"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.Query", r)
} }
}() }()
return b.db.NewRaw(query, args...).Scan(ctx, dest) return b.db.NewRaw(query, args...).Scan(ctx, dest)
@@ -86,8 +120,8 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunAdapter.RunInTransaction"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
} }
}() }()
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error { return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
@@ -99,12 +133,22 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
// BunSelectQuery implements SelectQuery for Bun // BunSelectQuery implements SelectQuery for Bun
type BunSelectQuery struct { type BunSelectQuery struct {
query *bun.SelectQuery query *bun.SelectQuery
db bun.IDB // Store DB connection for count queries db bun.IDB // Store DB connection for count queries
hasModel bool // Track if Model() was called hasModel bool // Track if Model() was called
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
deferredPreloads []deferredPreload // Preloads to execute as separate queries
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
}
// deferredPreload represents a preload that will be executed as a separate query
// to avoid PostgreSQL identifier length limits
type deferredPreload struct {
relation string
apply []func(common.SelectQuery) common.SelectQuery
} }
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery { func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -138,16 +182,156 @@ func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery {
} }
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.ColumnExpr(query, args) if len(args) > 0 {
b.query = b.query.ColumnExpr(query, args)
} else {
b.query = b.query.ColumnExpr(query)
}
return b return b
} }
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if b.inJoinContext && b.joinTableAlias != "" {
query = addTablePrefix(query, b.joinTableAlias)
} else if b.tableAlias != "" && b.tableName != "" {
// If we have a table alias defined, check if the query references a different alias
// This can happen in preloads where the user expects a certain alias but Bun generates another
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
}
b.query = b.query.Where(query, args...) b.query = b.query.Where(query, args...)
return b return b
} }
// addTablePrefix adds a table prefix to unqualified column references
// This is used in JOIN contexts where conditions must reference the joined table
func addTablePrefix(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
// (no dot, and likely a column name before an operator)
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeyword(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeyword checks if a string is likely an operator or SQL keyword
func isOperatorOrKeyword(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
// isAcronymMatch checks if prefix is an acronym of tableName
// For example, "apil" matches "apiproviderlink" because each letter appears in sequence
func isAcronymMatch(prefix, tableName string) bool {
if len(prefix) == 0 || len(tableName) == 0 {
return false
}
prefixIdx := 0
for i := 0; i < len(tableName) && prefixIdx < len(prefix); i++ {
if tableName[i] == prefix[prefixIdx] {
prefixIdx++
}
}
// All characters of prefix were found in sequence in tableName
return prefixIdx == len(prefix)
}
// normalizeTableAlias replaces table alias prefixes in SQL conditions
// This handles cases where a user references a table alias that doesn't match
// what Bun generates (common in preload contexts)
func normalizeTableAlias(query, expectedAlias, tableName string) string {
// Pattern: <word>.<column> where <word> might be an incorrect alias
// We'll look for patterns like "APIL.column" and either:
// 1. Remove the alias prefix if it's clearly meant for this table
// 2. Leave it alone if it might be referring to another table (JOIN/preload)
// Split on spaces and parentheses to find qualified references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like a qualified column reference
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
prefix := part[:dotIndex]
column := part[dotIndex+1:]
// Check if the prefix matches our expected alias or table name (case-insensitive)
if strings.EqualFold(prefix, expectedAlias) ||
strings.EqualFold(prefix, tableName) ||
strings.EqualFold(prefix, strings.ToLower(tableName)) {
// Prefix matches current table, it's safe but redundant - leave it
continue
}
// Check if the prefix could plausibly be an alias/acronym for this table
// Only strip if we're confident it's meant for this table
// For example: "APIL" could be an acronym for "apiproviderlink"
prefixLower := strings.ToLower(prefix)
tableNameLower := strings.ToLower(tableName)
// Check if prefix is a substring of table name
isSubstring := strings.Contains(tableNameLower, prefixLower) && len(prefixLower) > 2
// Check if prefix is an acronym of table name
// e.g., "APIL" matches "ApiProviderLink" (A-p-I-providerL-ink)
isAcronym := false
if !isSubstring && len(prefixLower) > 2 {
isAcronym = isAcronymMatch(prefixLower, tableNameLower)
}
if isSubstring || isAcronym {
// This looks like it could be an alias for this table - strip it
logger.Debug("Stripping plausible alias '%s' from WHERE condition, keeping just '%s'", prefix, column)
// Replace the qualified reference with just the column name
modified = strings.ReplaceAll(modified, part, column)
} else {
// Prefix doesn't match the current table at all
// It's likely referring to a different table (JOIN/preload)
// DON'T strip it - leave the qualified reference as-is
logger.Debug("Keeping qualified reference '%s' - prefix '%s' doesn't match current table '%s'", part, prefix, tableName)
}
}
}
return modified
}
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
b.query = b.query.WhereOr(query, args...) b.query = b.query.WhereOr(query, args...)
return b return b
@@ -233,8 +417,122 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
return b return b
} }
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
// // when combined with typical column names
// func shortenAliasForPostgres(relationPath string) (string, bool) {
// // Convert relation path to the alias format Bun uses: dots become double underscores
// // Also convert to lowercase and use snake_case as Bun does
// parts := strings.Split(relationPath, ".")
// alias := strings.ToLower(strings.Join(parts, "__"))
// // PostgreSQL truncates identifiers to 63 chars
// // If the alias + typical column name would exceed this, we need to shorten
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
// const maxAliasLength = 30
// if len(alias) > maxAliasLength {
// // Create a shortened alias using a hash of the original
// hash := md5.Sum([]byte(alias))
// hashStr := hex.EncodeToString(hash[:])[:8]
// // Keep first few chars of original for readability + hash
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
// if prefixLen > len(alias) {
// prefixLen = len(alias)
// }
// shortened := alias[:prefixLen] + "_" + hashStr
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
// alias, len(alias), shortened, len(shortened))
// return shortened, true
// }
// return alias, false
// }
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
// // Bun creates aliases like: relationChain__columnName
// func estimateColumnAliasLength(relationPath string, columnName string) int {
// relationParts := strings.Split(relationPath, ".")
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// // Bun adds "__" between alias and column name
// return len(aliasChain) + 2 + len(columnName)
// }
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from the query if available
model := b.query.GetModel()
if model != nil && model.Value() != nil {
relType := reflection.GetRelationType(model.Value(), relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return b.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Check if this relation chain would create problematic long aliases
relationParts := strings.Split(relation, ".")
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
// PostgreSQL's identifier limit is 63 characters
const postgresIdentifierLimit = 63
const safeAliasLimit = 35 // Leave room for column names
// If the alias chain is too long, defer this preload to be executed as a separate query
if len(aliasChain) > safeAliasLimit {
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
// This avoids the long concatenated alias
if len(relationParts) > 1 {
// Load first level normally: "Parent"
firstLevel := relationParts[0]
remainingPath := strings.Join(relationParts[1:], ".")
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
firstLevel, remainingPath)
// Apply the first level preload normally
b.query = b.query.Relation(firstLevel)
// Store the remaining nested preload to be executed after the main query
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
relation: relation,
apply: apply,
})
return b
}
// Single level but still too long - just warn and continue
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
"Consider renaming the field to avoid potential issues.",
relation, len(aliasChain))
}
// Normal preload handling
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery { b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
defer func() {
if r := recover(); r != nil {
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
if err != nil {
return
}
}
}()
if len(apply) == 0 { if len(apply) == 0 {
return sq return sq
} }
@@ -245,6 +543,28 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
db: b.db, db: b.db,
} }
// Try to extract table name and alias from the preload model
if model := sq.GetModel(); model != nil && model.Value() != nil {
modelValue := model.Value()
// Extract table name if model implements TableNameProvider
if provider, ok := modelValue.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
}
// Extract table alias if model implements TableAliasProvider
if provider, ok := modelValue.(common.TableAliasProvider); ok {
wrapper.tableAlias = provider.TableAlias()
// Apply the alias to the Bun query so conditions can reference it
if wrapper.tableAlias != "" {
// Note: Bun's Relation() already sets up the table, but we can add
// the alias explicitly if needed
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
}
}
}
// Start with the interface value (not pointer) // Start with the interface value (not pointer)
current := common.SelectQuery(wrapper) current := common.SelectQuery(wrapper)
@@ -267,6 +587,36 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
return b return b
} }
func (b *BunSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a LEFT JOIN instead of a separate query
// This is more efficient for many-to-one or one-to-one relationships
logger.Debug("JoinRelation '%s' - Using JOIN strategy with automatic WHERE prefix addition", relation)
// Wrap the apply functions to automatically add table prefix to WHERE conditions
wrappedApply := make([]func(common.SelectQuery) common.SelectQuery, 0, len(apply))
for _, fn := range apply {
if fn != nil {
wrappedFn := func(originalFn func(common.SelectQuery) common.SelectQuery) func(common.SelectQuery) common.SelectQuery {
return func(q common.SelectQuery) common.SelectQuery {
// Create a special wrapper that adds prefixes to WHERE conditions
if bunQuery, ok := q.(*BunSelectQuery); ok {
// Mark this query as being in JOIN context
bunQuery.inJoinContext = true
bunQuery.joinTableAlias = strings.ToLower(relation)
}
return originalFn(q)
}
}(fn)
wrappedApply = append(wrappedApply, wrappedFn)
}
}
// Use PreloadRelation with the wrapped functions
// Bun's Relation() will use JOIN for belongs-to and has-one relations
return b.PreloadRelation(relation, wrappedApply...)
}
func (b *BunSelectQuery) Order(order string) common.SelectQuery { func (b *BunSelectQuery) Order(order string) common.SelectQuery {
b.query = b.query.Order(order) b.query = b.query.Order(order)
return b return b
@@ -294,59 +644,222 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Scan"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Scan", r)
} }
}() }()
if dest == nil { if dest == nil {
return fmt.Errorf("destination cannot be nil") return fmt.Errorf("destination cannot be nil")
} }
return b.query.Scan(ctx, dest)
// Execute the main query first
err = b.query.Scan(ctx, dest)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
err = b.executeDeferredPreloads(ctx, dest)
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
} }
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) { func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.ScanModel"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
} }
}() }()
if b.query.GetModel() == nil { if b.query.GetModel() == nil {
return fmt.Errorf("model is nil") return fmt.Errorf("model is nil")
} }
return b.query.Scan(ctx) // Execute the main query first
err = b.query.Scan(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
return err
}
// Execute any deferred preloads
if len(b.deferredPreloads) > 0 {
model := b.query.GetModel()
err = b.executeDeferredPreloads(ctx, model.Value())
if err != nil {
logger.Warn("Failed to execute deferred preloads: %v", err)
// Don't fail the whole query, just log the warning
}
}
return nil
}
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
if len(b.deferredPreloads) == 0 {
return nil
}
for _, dp := range b.deferredPreloads {
err := b.executeSingleDeferredPreload(ctx, dest, dp)
if err != nil {
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
}
}
return nil
}
// executeSingleDeferredPreload executes a single deferred preload
// For a relation like "Parent.Child", it:
// 1. Finds all loaded Parent records in dest
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
relationParts := strings.Split(dp.relation, ".")
if len(relationParts) < 2 {
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
}
// The parent relation that was already loaded
parentRelation := relationParts[0]
// The child relation we need to load
childRelation := strings.Join(relationParts[1:], ".")
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
// Use reflection to access the parent relation field(s) in the loaded records
// Then load the child relation for those parent records
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
// Handle both slice and single record
if destValue.Kind() == reflect.Slice {
// Iterate through each record in the slice
for i := 0; i < destValue.Len(); i++ {
record := destValue.Index(i)
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
// Continue with other records
}
}
} else {
// Single record
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
}
}
return nil
}
// loadChildRelationForRecord loads a child relation for a single parent record
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
// Ensure we're working with the actual struct value, not a pointer
if record.Kind() == reflect.Ptr {
record = record.Elem()
}
// Get the parent relation field
parentField := record.FieldByName(parentRelation)
if !parentField.IsValid() {
// Parent relation field doesn't exist
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
return nil
}
// Check if the parent field is nil (for pointer fields)
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
// Parent relation not loaded or nil, skip
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
return nil
}
// Get the interface value to pass to Bun
parentValue := parentField.Interface()
// Load the child relation on the parent record
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
return b.db.NewSelect().
Model(parentValue).
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
// Apply any custom query modifications
if len(apply) > 0 {
wrapper := &BunSelectQuery{query: sq, db: b.db}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalBun, ok := current.(*BunSelectQuery); ok {
return finalBun.query
}
}
return sq
}).
Scan(ctx)
} }
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) { func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Count"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Count", r)
count = 0 count = 0
} }
}() }()
// If Model() was set, use bun's native Count() which works properly // If Model() was set, use bun's native Count() which works properly
if b.hasModel { if b.hasModel {
count, err := b.query.Count(ctx) count, err := b.query.Count(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err return count, err
} }
// Otherwise, wrap as subquery to avoid "Model(nil)" error // Otherwise, wrap as subquery to avoid "Model(nil)" error
// This is needed when only Table() is set without a model // This is needed when only Table() is set without a model
err = b.db.NewSelect(). countQuery := b.db.NewSelect().
TableExpr("(?) AS subquery", b.query). TableExpr("(?) AS subquery", b.query).
ColumnExpr("COUNT(*)"). ColumnExpr("COUNT(*)")
Scan(ctx, &count) err = countQuery.Scan(ctx, &count)
if err != nil {
// Log SQL string for debugging
sqlStr := countQuery.String()
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
}
return count, err return count, err
} }
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) { func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunSelectQuery.Exists"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunSelectQuery.Exists", r)
exists = false exists = false
} }
}() }()
return b.query.Exists(ctx) exists, err = b.query.Exists(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return exists, err
} }
// BunInsertQuery implements InsertQuery for Bun // BunInsertQuery implements InsertQuery for Bun
@@ -392,11 +905,11 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunInsertQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunInsertQuery.Exec", r)
} }
}() }()
if b.values != nil && len(b.values) > 0 { if len(b.values) > 0 {
if !b.hasModel { if !b.hasModel {
// If no model was set, use the values map as the model // If no model was set, use the values map as the model
// Bun can insert map[string]interface{} directly // Bun can insert map[string]interface{} directly
@@ -478,11 +991,16 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunUpdateQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunUpdateQuery.Exec", r)
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }
@@ -508,11 +1026,16 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("BunDeleteQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("BunDeleteQuery.Exec", r)
} }
}() }()
result, err := b.query.Exec(ctx) result, err := b.query.Exec(ctx)
if err != nil {
// Log SQL string for debugging
sqlStr := b.query.String()
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
}
return &BunResult{result: result}, err return &BunResult{result: result}, err
} }

View File

@@ -23,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db} return &GormAdapter{db: db}
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.db = g.db.Debug()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
// DisableQueryDebug disables query debugging
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
// GORM's Debug() creates a new session, so we need to get the base DB
// This is a simplified implementation
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
return g
}
func (g *GormAdapter) NewSelect() common.SelectQuery { func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db} return &GormSelectQuery{db: g.db}
} }
@@ -41,8 +57,8 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) { func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Exec(query, args...) result := g.db.WithContext(ctx).Exec(query, args...)
@@ -51,8 +67,8 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) { func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.Query"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.Query", r)
} }
}() }()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
@@ -76,8 +92,8 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) { func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormAdapter.RunInTransaction"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
} }
}() }()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
@@ -88,10 +104,12 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
// GormSelectQuery implements SelectQuery for GORM // GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct { type GormSelectQuery struct {
db *gorm.DB db *gorm.DB
schema string // Separated schema name schema string // Separated schema name
tableName string // Just the table name, without schema tableName string // Just the table name, without schema
tableAlias string tableAlias string
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
@@ -125,15 +143,71 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
} }
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Select(query, args...) if len(args) > 0 {
g.db = g.db.Select(query, args...)
} else {
g.db = g.db.Select(query)
}
return g return g
} }
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if g.inJoinContext && g.joinTableAlias != "" {
query = addTablePrefixGorm(query, g.joinTableAlias)
}
g.db = g.db.Where(query, args...) g.db = g.db.Where(query, args...)
return g return g
} }
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
func addTablePrefixGorm(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeywordGorm(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
func isOperatorOrKeywordGorm(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Or(query, args...) g.db = g.db.Or(query, args...)
return g return g
@@ -217,6 +291,27 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
} }
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery { func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from GORM's statement if available
if g.db.Statement != nil && g.db.Statement.Model != nil {
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return g.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Use GORM's Preload (separate query strategy)
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB { g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 { if len(apply) == 0 {
return db return db
@@ -246,6 +341,42 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
return g return g
} }
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a JOIN instead of a separate preload query
// This is more efficient for many-to-one or one-to-one relationships
// as it avoids additional round trips to the database
// GORM's Joins() method forces a JOIN for the preload
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalGorm, ok := current.(*GormSelectQuery); ok {
return finalGorm.db
}
return db
})
return g
}
func (g *GormSelectQuery) Order(order string) common.SelectQuery { func (g *GormSelectQuery) Order(order string) common.SelectQuery {
g.db = g.db.Order(order) g.db = g.db.Order(order)
return g return g
@@ -273,46 +404,76 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) { func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Scan"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Scan", r)
} }
}() }()
return g.db.WithContext(ctx).Find(dest).Error err = g.db.WithContext(ctx).Find(dest).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(dest)
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) { func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.ScanModel"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
} }
}() }()
if g.db.Statement.Model == nil { if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning") return fmt.Errorf("ScanModel requires Model() to be set before scanning")
} }
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(g.db.Statement.Model)
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) { func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Count"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0 count = 0
} }
}() }()
var count64 int64 var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error err = g.db.WithContext(ctx).Count(&count64).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Count(&count64)
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return int(count64), err return int(count64), err
} }
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) { func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormSelectQuery.Exists"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false exists = false
} }
}() }()
var count int64 var count int64
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Limit(1).Count(&count)
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return count > 0, err return count > 0, err
} }
@@ -354,8 +515,8 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormInsertQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormInsertQuery.Exec", r)
} }
}() }()
var result *gorm.DB var result *gorm.DB
@@ -446,11 +607,18 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormUpdateQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormUpdateQuery.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Updates(g.updates) result := g.db.WithContext(ctx).Updates(g.updates)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Updates(g.updates)
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }
@@ -478,11 +646,18 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) { func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() { defer func() {
if panicErr := logger.RecoverPanic("GormDeleteQuery.Exec"); panicErr != nil { if r := recover(); r != nil {
err = panicErr err = logger.HandlePanic("GormDeleteQuery.Exec", r)
} }
}() }()
result := g.db.WithContext(ctx).Delete(g.model) result := g.db.WithContext(ctx).Delete(g.model)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Delete(g.model)
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }

View File

@@ -6,6 +6,7 @@ import (
"github.com/uptrace/bunrouter" "github.com/uptrace/bunrouter"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface // BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
@@ -35,7 +36,11 @@ func (b *BunRouterAdapter) HandleFunc(pattern string, handler common.HTTPHandler
func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface // This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router // For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetBunRouter() for direct access") w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
} }
// GetBunRouter returns the underlying bunrouter for direct access // GetBunRouter returns the underlying bunrouter for direct access
@@ -121,6 +126,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
return b.req.URL.Query().Get(key) return b.req.URL.Query().Get(key)
} }
func (b *BunRouterRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range b.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (b *BunRouterRequest) AllHeaders() map[string]string { func (b *BunRouterRequest) AllHeaders() map[string]string {
headers := make(map[string]string) headers := make(map[string]string)
for key, values := range b.req.Header { for key, values := range b.req.Header {
@@ -131,6 +146,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
return headers return headers
} }
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
return b.req.Request
}
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers // StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
type StandardBunRouterAdapter struct { type StandardBunRouterAdapter struct {
*BunRouterAdapter *BunRouterAdapter

View File

@@ -8,6 +8,7 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// MuxAdapter adapts Gorilla Mux to work with our Router interface // MuxAdapter adapts Gorilla Mux to work with our Router interface
@@ -32,7 +33,11 @@ func (m *MuxAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc)
func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface // This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router // For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetMuxRouter() for direct access") w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
} }
// MuxRouteRegistration implements RouteRegistration for Mux // MuxRouteRegistration implements RouteRegistration for Mux
@@ -117,6 +122,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
return h.req.URL.Query().Get(key) return h.req.URL.Query().Get(key)
} }
func (h *HTTPRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range h.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (h *HTTPRequest) AllHeaders() map[string]string { func (h *HTTPRequest) AllHeaders() map[string]string {
headers := make(map[string]string) headers := make(map[string]string)
for key, values := range h.req.Header { for key, values := range h.req.Header {
@@ -127,6 +142,12 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
return headers return headers
} }
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
return h.req
}
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter // HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
type HTTPResponseWriter struct { type HTTPResponseWriter struct {
resp http.ResponseWriter resp http.ResponseWriter
@@ -156,6 +177,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
return json.NewEncoder(h.resp).Encode(data) return json.NewEncoder(h.resp).Encode(data)
} }
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
// This is useful when you need to pass the response writer to other handlers
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return h.resp
}
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc // StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
type StandardMuxAdapter struct { type StandardMuxAdapter struct {
*MuxAdapter *MuxAdapter

119
pkg/common/cors.go Normal file
View File

@@ -0,0 +1,119 @@
package common
import (
"fmt"
"strings"
)
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
MaxAge int
}
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowedHeaders: GetHeadSpecHeaders(),
MaxAge: 86400, // 24 hours
}
}
// GetHeadSpecHeaders returns all headers used by HeadSpec
func GetHeadSpecHeaders() []string {
return []string{
// Standard headers
"Content-Type",
"Authorization",
"Accept",
"Accept-Language",
"Content-Language",
// Field Selection
"X-Select-Fields",
"X-Not-Select-Fields",
"X-Clean-JSON",
// Filtering & Search
"X-FieldFilter-*",
"X-SearchFilter-*",
"X-SearchOp-*",
"X-SearchOr-*",
"X-SearchAnd-*",
"X-SearchCols",
"X-Custom-SQL-W",
"X-Custom-SQL-W-*",
"X-Custom-SQL-Or",
"X-Custom-SQL-Or-*",
// Joins & Relations
"X-Preload",
"X-Preload-*",
"X-Expand",
"X-Expand-*",
"X-Custom-SQL-Join",
"X-Custom-SQL-Join-*",
// Sorting & Pagination
"X-Sort",
"X-Sort-*",
"X-Limit",
"X-Offset",
"X-Cursor-Forward",
"X-Cursor-Backward",
// Advanced Features
"X-AdvSQL-*",
"X-CQL-Sel-*",
"X-Distinct",
"X-SkipCount",
"X-SkipCache",
"X-Fetch-RowNumber",
"X-PKRow",
// Response Format
"X-SimpleAPI",
"X-DetailAPI",
"X-Syncfusion",
"X-Single-Record-As-Object",
// Transaction Control
"X-Transaction-Atomic",
// X-Files - comprehensive JSON configuration
"X-Files",
}
}
// SetCORSHeaders sets CORS headers on a response writer
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
// Set allowed origins
if len(config.AllowedOrigins) > 0 {
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
}
// Set allowed methods
if len(config.AllowedMethods) > 0 {
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
}
// Set allowed headers
if len(config.AllowedHeaders) > 0 {
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
}
// Set max age
if config.MaxAge > 0 {
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
}
// Allow credentials
w.SetHeader("Access-Control-Allow-Credentials", "true")
// Expose headers that clients can read
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
}

View File

@@ -0,0 +1,97 @@
package common
// Example showing how to use the common handler interfaces
// This file demonstrates the handler interface hierarchy and usage patterns
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
// which works with any handler type (resolvespec, restheadspec, or funcspec)
func ProcessWithAnyHandler(handler SpecHandler) Database {
// All handlers expose GetDatabase() through the SpecHandler interface
return handler.GetDatabase()
}
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
// which works with resolvespec.Handler and restheadspec.Handler
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement Handle()
handler.Handle(w, r, params)
}
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement HandleGet()
handler.HandleGet(w, r, params)
}
// Example usage patterns (not executable, just for documentation):
/*
// Example 1: Using with resolvespec.Handler
func ExampleResolveSpec() {
db := // ... get database
registry := // ... get registry
handler := resolvespec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 2: Using with restheadspec.Handler
func ExampleRestHeadSpec() {
db := // ... get database
registry := // ... get registry
handler := restheadspec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 3: Using with funcspec.Handler
func ExampleFuncSpec() {
db := // ... get database
handler := funcspec.NewHandler(db)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as QueryHandler
var queryHandler QueryHandler = handler
// funcspec has different methods: SqlQueryList() and SqlQuery()
// which return HTTP handler functions
}
// Example 4: Polymorphic handler processing
func ProcessHandlers(handlers []SpecHandler) {
for _, handler := range handlers {
// All handlers expose the database
db := handler.GetDatabase()
// Type switch for specific handler types
switch h := handler.(type) {
case CRUDHandler:
// This is resolvespec or restheadspec
// Can call Handle() and HandleGet()
_ = h
case QueryHandler:
// This is funcspec
// Can call SqlQueryList() and SqlQuery()
_ = h
}
}
}
*/

View File

@@ -0,0 +1,47 @@
package common
import (
"fmt"
"reflect"
)
// ValidateAndUnwrapModelResult contains the result of model validation
type ValidateAndUnwrapModelResult struct {
ModelType reflect.Type
Model interface{}
ModelPtr interface{}
OriginalType reflect.Type
}
// ValidateAndUnwrapModel validates that a model is a struct type and unwraps
// pointers, slices, and arrays to get to the base struct type.
// Returns an error if the model is not a valid struct type.
func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, error) {
modelType := reflect.TypeOf(model)
originalType := modelType
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, fmt.Errorf("model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType)
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
return &ValidateAndUnwrapModelResult{
ModelType: modelType,
Model: model,
ModelPtr: modelPtr,
OriginalType: originalType,
}, nil
}

View File

@@ -1,6 +1,11 @@
package common package common
import "context" import (
"context"
"encoding/json"
"io"
"net/http"
)
// Database interface designed to work with both GORM and Bun // Database interface designed to work with both GORM and Bun
type Database interface { type Database interface {
@@ -33,6 +38,7 @@ type SelectQuery interface {
LeftJoin(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery
Preload(relation string, conditions ...interface{}) SelectQuery Preload(relation string, conditions ...interface{}) SelectQuery
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery Order(order string) SelectQuery
Limit(n int) SelectQuery Limit(n int) SelectQuery
Offset(n int) SelectQuery Offset(n int) SelectQuery
@@ -116,6 +122,8 @@ type Request interface {
Body() ([]byte, error) Body() ([]byte, error)
PathParam(key string) string PathParam(key string) string
QueryParam(key string) string QueryParam(key string) string
AllQueryParams() map[string]string // Get all query parameters as a map
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
} }
// ResponseWriter interface abstracts HTTP response // ResponseWriter interface abstracts HTTP response
@@ -124,11 +132,113 @@ type ResponseWriter interface {
WriteHeader(statusCode int) WriteHeader(statusCode int)
Write(data []byte) (int, error) Write(data []byte) (int, error)
WriteJSON(data interface{}) error WriteJSON(data interface{}) error
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
} }
// HTTPHandlerFunc type for HTTP handlers // HTTPHandlerFunc type for HTTP handlers
type HTTPHandlerFunc func(ResponseWriter, Request) type HTTPHandlerFunc func(ResponseWriter, Request)
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
}
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
type StandardResponseWriter struct {
w http.ResponseWriter
status int
}
func (s *StandardResponseWriter) SetHeader(key, value string) {
s.w.Header().Set(key, value)
}
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
s.status = statusCode
s.w.WriteHeader(statusCode)
}
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
return s.w.Write(data)
}
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data)
}
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return s.w
}
// StandardRequest adapts *http.Request to Request interface
type StandardRequest struct {
r *http.Request
body []byte
}
func (s *StandardRequest) Method() string {
return s.r.Method
}
func (s *StandardRequest) URL() string {
return s.r.URL.String()
}
func (s *StandardRequest) Header(key string) string {
return s.r.Header.Get(key)
}
func (s *StandardRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range s.r.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
return headers
}
func (s *StandardRequest) Body() ([]byte, error) {
if s.body != nil {
return s.body, nil
}
if s.r.Body == nil {
return nil, nil
}
defer s.r.Body.Close()
body, err := io.ReadAll(s.r.Body)
if err != nil {
return nil, err
}
s.body = body
return body, nil
}
func (s *StandardRequest) PathParam(key string) string {
// Standard http.Request doesn't have path params
// This should be set by the router
return ""
}
func (s *StandardRequest) QueryParam(key string) string {
return s.r.URL.Query().Get(key)
}
func (s *StandardRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range s.r.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (s *StandardRequest) UnderlyingRequest() *http.Request {
return s.r
}
// TableNameProvider interface for models that provide table names // TableNameProvider interface for models that provide table names
type TableNameProvider interface { type TableNameProvider interface {
TableName() string TableName() string
@@ -147,3 +257,39 @@ type PrimaryKeyNameProvider interface {
type SchemaProvider interface { type SchemaProvider interface {
SchemaName() string SchemaName() string
} }
// SpecHandler interface represents common functionality across all spec handlers
// This is the base interface implemented by:
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
//
// The interface hierarchy is:
//
// SpecHandler (base)
// ├── CRUDHandler (resolvespec, restheadspec)
// └── QueryHandler (funcspec)
type SpecHandler interface {
// GetDatabase returns the underlying database connection
GetDatabase() Database
}
// CRUDHandler interface for handlers that support CRUD operations
// This is implemented by resolvespec.Handler and restheadspec.Handler
type CRUDHandler interface {
SpecHandler
// Handle processes API requests through router-agnostic interface
Handle(w ResponseWriter, r Request, params map[string]string)
// HandleGet processes GET requests for metadata
HandleGet(w ResponseWriter, r Request, params map[string]string)
}
// QueryHandler interface for handlers that execute SQL queries
// This is implemented by funcspec.Handler
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
type QueryHandler interface {
SpecHandler
// Methods are defined in funcspec package due to different function signature requirements
}

447
pkg/common/sql_helpers.go Normal file
View File

@@ -0,0 +1,447 @@
package common
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
//
// NOTE: For preload queries, table aliases from the parent query are not valid since
// the preload executes as a separate query with its own table alias. This function
// now simply validates basic syntax without requiring or adding prefixes.
// The actual alias normalization happens in the database adapter layer.
//
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
where = strings.TrimSpace(where)
// Just do basic validation - don't require or add prefixes
// The database adapter will handle alias normalization
// Check if the WHERE clause contains any qualified column references
// If it does, log a debug message but don't fail - let the adapter handle it
if strings.Contains(where, ".") {
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
"Note: In preload context, table aliases from parent query are not available. "+
"The database adapter will normalize aliases automatically.", relationName, where)
}
// Validate that it's not empty or just whitespace
if where == "" {
return where, nil
}
// Return the WHERE clause as-is
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
return where, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
// Returns an error if any dangerous keywords are found
func validateWhereClauseSecurity(where string) error {
if where == "" {
return nil
}
lowerWhere := strings.ToLower(where)
// List of dangerous SQL keywords that should never appear in WHERE clauses
dangerousKeywords := []string{
"delete ", "delete\t", "delete\n", "delete;",
"update ", "update\t", "update\n", "update;",
"truncate ", "truncate\t", "truncate\n", "truncate;",
"drop ", "drop\t", "drop\n", "drop;",
"alter ", "alter\t", "alter\n", "alter;",
"create ", "create\t", "create\n", "create;",
"insert ", "insert\t", "insert\n", "insert;",
"grant ", "grant\t", "grant\n", "grant;",
"revoke ", "revoke\t", "revoke\n", "revoke;",
"exec ", "exec\t", "exec\n", "exec;",
"execute ", "execute\t", "execute\n", "execute;",
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(lowerWhere, keyword) {
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
}
}
return nil
}
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
// - An empty string if all conditions were trivial or the input was empty
//
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
// prefix matches a preloaded relation name, in which case it's left unchanged.
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Validate that the WHERE clause doesn't contain dangerous SQL statements
if err := validateWhereClauseSecurity(where); err != nil {
logger.Debug("Security validation failed for WHERE clause: %v", err)
return ""
}
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Build a set of allowed table prefixes (main table + preloaded relations)
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
}
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition HAS a table prefix, check if it's correct
if tableName != "" && hasTablePrefix(condToCheck) {
// Extract the current prefix and column name
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name
oldRef := currentPrefix + "." + columnName
newRef := tableName + "." + columnName
cond = strings.Replace(cond, oldRef, newRef, 1)
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
} else {
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
}
}
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string {
conditions := []string{}
currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth
i := 0
for i < len(where) {
ch := where[i]
// Track parenthesis depth
if ch == '(' {
depth++
currentCondition.WriteByte(ch)
i++
continue
} else if ch == ')' {
depth--
currentCondition.WriteByte(ch)
i++
continue
}
// Only look for AND operators at depth 0 (not inside parentheses)
if depth == 0 {
// Check if we're at an AND operator (case-insensitive)
// We need at least " AND " (5 chars) or " and " (5 chars)
if i+5 <= len(where) {
substring := where[i : i+5]
lowerSubstring := strings.ToLower(substring)
if lowerSubstring == " and " {
// Found an AND operator at the top level
// Add the current condition to the list
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
// Skip past the AND operator
i += 5
continue
}
}
}
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch)
i++
}
// Add the last condition
if currentCondition.Len() > 0 {
conditions = append(conditions, currentCondition.String())
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
// For example: "users.status = 'active'" returns ("users", "status")
// Returns empty strings if no table prefix is found
func extractTableAndColumn(cond string) (table string, column string) {
// Common SQL operators to find the column reference
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
var columnRef string
// Find the column reference (left side of the operator)
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnRef = strings.TrimSpace(cond[:idx])
break
}
}
// If no operator found, the whole condition might be the column reference
if columnRef == "" {
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return "", ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if it contains a dot (qualified reference)
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
return "", ""
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@@ -0,0 +1,579 @@
package common
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses - no prefix added",
where: "(status = 'active')",
tableName: "users",
expected: "status = 'active'",
},
{
name: "mixed trivial and valid conditions - no prefix added",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "status = 'active'",
},
{
name: "condition with correct table prefix - unchanged",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition with incorrect table prefix - fixed",
where: "wrong_table.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple conditions with incorrect prefix - fixed",
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "multiple valid conditions without prefix - no prefix added",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "status = 'active' AND age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
{
name: "mixed correct and incorrect prefixes",
where: "users.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "status = 'active' AND age > 18 AND name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",
where: "status = 'active'; DELETE FROM users",
tableName: "users",
expected: "",
},
{
name: "dangerous UPDATE keyword - blocked",
where: "1=1; UPDATE users SET admin = true",
tableName: "users",
expected: "",
},
{
name: "dangerous TRUNCATE keyword - blocked",
where: "status = 'active' OR TRUNCATE TABLE users",
tableName: "users",
expected: "",
},
{
name: "dangerous DROP keyword - blocked",
where: "status = 'active'; DROP TABLE users",
tableName: "users",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
{
name: "complex sub query",
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestExtractTableAndColumn(t *testing.T) {
tests := []struct {
name string
input string
expectedTable string
expectedCol string
}{
{
name: "qualified column with equals",
input: "users.status = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "qualified column with greater than",
input: "users.age > 18",
expectedTable: "users",
expectedCol: "age",
},
{
name: "qualified column with LIKE",
input: "users.name LIKE '%john%'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "qualified column with IN",
input: "users.status IN ('active', 'pending')",
expectedTable: "users",
expectedCol: "status",
},
{
name: "unqualified column",
input: "status = 'active'",
expectedTable: "",
expectedCol: "",
},
{
name: "qualified with backticks",
input: "`users`.`status` = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "schema.table.column reference",
input: "public.users.status = 'active'",
expectedTable: "public.users",
expectedCol: "status",
},
{
name: "empty string",
input: "",
expectedTable: "",
expectedCol: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table, col := extractTableAndColumn(tt.input)
if table != tt.expectedTable || col != tt.expectedCol {
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
tt.input, table, col, tt.expectedTable, tt.expectedCol)
}
})
}
}
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "preload relation prefix is preserved",
where: "Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "Department.name = 'Engineering'",
},
{
name: "multiple preload relations - all preserved",
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
{Relation: "Manager"},
},
},
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
},
{
name: "mix of main table and preload relation",
where: "users.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "incorrect prefix fixed when not a preload relation",
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: nil,
expected: "users.status = 'active'",
},
{
name: "empty preload list - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: &RequestOptions{Preload: []PreloadOption{}},
expected: "users.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
if tt.options != nil {
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
} else {
result = SanitizeWhereClause(tt.where, tt.tableName)
}
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSplitByAND(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "uppercase AND",
input: "status = 'active' AND age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "lowercase and",
input: "status = 'active' and age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "mixed case AND",
input: "status = 'active' AND age > 18 and name = 'John'",
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
},
{
name: "single condition",
input: "status = 'active'",
expected: []string{"status = 'active'"},
},
{
name: "multiple uppercase AND",
input: "a = 1 AND b = 2 AND c = 3",
expected: []string{"a = 1", "b = 2", "c = 3"},
},
{
name: "multiple case subquery",
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitByAND(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
return
}
for i := range result {
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
}
}
})
}
}
func TestValidateWhereClauseSecurity(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "safe WHERE clause",
input: "status = 'active' AND age > 18",
expectError: false,
},
{
name: "safe subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expectError: false,
},
{
name: "DELETE keyword",
input: "status = 'active'; DELETE FROM users",
expectError: true,
},
{
name: "UPDATE keyword",
input: "1=1; UPDATE users SET admin = true",
expectError: true,
},
{
name: "TRUNCATE keyword",
input: "status = 'active' OR TRUNCATE TABLE users",
expectError: true,
},
{
name: "DROP keyword",
input: "status = 'active'; DROP TABLE users",
expectError: true,
},
{
name: "INSERT keyword",
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
expectError: true,
},
{
name: "ALTER keyword",
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
expectError: true,
},
{
name: "CREATE keyword",
input: "1=1; CREATE TABLE malicious (id INT)",
expectError: true,
},
{
name: "empty clause",
input: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateWhereClauseSecurity(tt.input)
if tt.expectError && err == nil {
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
}
if !tt.expectError && err != nil {
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
}
})
}
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column without prefix - no prefix added",
where: "status = 'active'",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "multiple valid columns without prefix - no prefix added",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "status = 'active' AND user_id = 123",
},
{
name: "incorrect table prefix on valid column - fixed",
where: "wrong_table.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "incorrect prefix on invalid column - not fixed",
where: "wrong_table.invalid_column = 'value'",
tableName: "mastertask",
expected: "wrong_table.invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "parentheses with valid column - no prefix added",
where: "(status = 'active')",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "correct prefix - unchanged",
where: "mastertask.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple conditions with mixed prefixes",
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -9,18 +9,18 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
// TestSqlInt16 tests SqlInt16 type // TestNewSqlInt16 tests NewSqlInt16 type
func TestSqlInt16(t *testing.T) { func TestNewSqlInt16(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
input interface{} input interface{}
expected SqlInt16 expected SqlInt16
}{ }{
{"int", 42, SqlInt16(42)}, {"int", 42, Null(int16(42), true)},
{"int32", int32(100), SqlInt16(100)}, {"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), SqlInt16(200)}, {"int64", int64(200), NewSqlInt16(200)},
{"string", "123", SqlInt16(123)}, {"string", "123", NewSqlInt16(123)},
{"nil", nil, SqlInt16(0)}, {"nil", nil, Null(int16(0), false)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -36,15 +36,15 @@ func TestSqlInt16(t *testing.T) {
} }
} }
func TestSqlInt16_Value(t *testing.T) { func TestNewSqlInt16_Value(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
input SqlInt16 input SqlInt16
expected driver.Value expected driver.Value
}{ }{
{"zero", SqlInt16(0), nil}, {"zero", Null(int16(0), false), nil},
{"positive", SqlInt16(42), int64(42)}, {"positive", NewSqlInt16(42), int16(42)},
{"negative", SqlInt16(-10), int64(-10)}, {"negative", NewSqlInt16(-10), int16(-10)},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -60,8 +60,8 @@ func TestSqlInt16_Value(t *testing.T) {
} }
} }
func TestSqlInt16_JSON(t *testing.T) { func TestNewSqlInt16_JSON(t *testing.T) {
n := SqlInt16(42) n := NewSqlInt16(42)
// Marshal // Marshal
data, err := json.Marshal(n) data, err := json.Marshal(n)
@@ -78,24 +78,24 @@ func TestSqlInt16_JSON(t *testing.T) {
if err := json.Unmarshal([]byte("123"), &n2); err != nil { if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err) t.Fatalf("Unmarshal failed: %v", err)
} }
if n2 != 123 { if n2.Int64() != 123 {
t.Errorf("expected 123, got %d", n2) t.Errorf("expected 123, got %d", n2.Int64())
} }
} }
// TestSqlInt64 tests SqlInt64 type // TestNewSqlInt64 tests NewSqlInt64 type
func TestSqlInt64(t *testing.T) { func TestNewSqlInt64(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
input interface{} input interface{}
expected SqlInt64 expected SqlInt64
}{ }{
{"int", 42, SqlInt64(42)}, {"int", 42, NewSqlInt64(42)},
{"int32", int32(100), SqlInt64(100)}, {"int32", int32(100), NewSqlInt64(100)},
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)}, {"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
{"uint32", uint32(100), SqlInt64(100)}, {"uint32", uint32(100), NewSqlInt64(100)},
{"uint64", uint64(200), SqlInt64(200)}, {"uint64", uint64(200), NewSqlInt64(200)},
{"nil", nil, SqlInt64(0)}, {"nil", nil, SqlInt64{}},
} }
for _, tt := range tests { for _, tt := range tests {
@@ -135,8 +135,8 @@ func TestSqlFloat64(t *testing.T) {
if n.Valid != tt.valid { if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid) t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
} }
if tt.valid && n.Float64 != tt.expected { if tt.valid && n.Float64() != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64) t.Errorf("expected %v, got %v", tt.expected, n.Float64())
} }
}) })
} }
@@ -162,7 +162,7 @@ func TestSqlTimeStamp(t *testing.T) {
if err := ts.Scan(tt.input); err != nil { if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err) t.Fatalf("Scan failed: %v", err)
} }
if ts.GetTime().IsZero() { if ts.Time().IsZero() {
t.Error("expected non-zero time") t.Error("expected non-zero time")
} }
}) })
@@ -171,7 +171,7 @@ func TestSqlTimeStamp(t *testing.T) {
func TestSqlTimeStamp_JSON(t *testing.T) { func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC) now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := SqlTimeStamp(now) ts := NewSqlTimeStamp(now)
// Marshal // Marshal
data, err := json.Marshal(ts) data, err := json.Marshal(ts)
@@ -188,8 +188,8 @@ func TestSqlTimeStamp_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil { if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err) t.Fatalf("Unmarshal failed: %v", err)
} }
if ts2.GetTime().Year() != 2024 { if ts2.Time().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year()) t.Errorf("expected year 2024, got %d", ts2.Time().Year())
} }
// Test null // Test null
@@ -226,7 +226,7 @@ func TestSqlDate(t *testing.T) {
} }
func TestSqlDate_JSON(t *testing.T) { func TestSqlDate_JSON(t *testing.T) {
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC)) date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal // Marshal
data, err := json.Marshal(date) data, err := json.Marshal(date)
@@ -471,8 +471,8 @@ func TestSqlUUID_Scan(t *testing.T) {
if u.Valid != tt.valid { if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid) t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
} }
if tt.valid && u.String != tt.expected { if tt.valid && u.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String) t.Errorf("expected %s, got %s", tt.expected, u.String())
} }
}) })
} }
@@ -480,13 +480,13 @@ func TestSqlUUID_Scan(t *testing.T) {
func TestSqlUUID_Value(t *testing.T) { func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New() testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true} u := NewSqlUUID(testUUID)
val, err := u.Value() val, err := u.Value()
if err != nil { if err != nil {
t.Fatalf("Value failed: %v", err) t.Fatalf("Value failed: %v", err)
} }
if val != testUUID.String() { if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val) t.Errorf("expected %s, got %s", testUUID.String(), val)
} }
@@ -503,7 +503,7 @@ func TestSqlUUID_Value(t *testing.T) {
func TestSqlUUID_JSON(t *testing.T) { func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New() testUUID := uuid.New()
u := SqlUUID{String: testUUID.String(), Valid: true} u := NewSqlUUID(testUUID)
// Marshal // Marshal
data, err := json.Marshal(u) data, err := json.Marshal(u)
@@ -520,8 +520,8 @@ func TestSqlUUID_JSON(t *testing.T) {
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil { if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err) t.Fatalf("Unmarshal failed: %v", err)
} }
if u2.String != testUUID.String() { if u2.String() != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String) t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
} }
// Test null // Test null

View File

@@ -32,15 +32,22 @@ type Parameter struct {
} }
type PreloadOption struct { type PreloadOption struct {
Relation string `json:"relation"` Relation string `json:"relation"`
Columns []string `json:"columns"` Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"` OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"` Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"` Filters []FilterOption `json:"filters"`
Where string `json:"where"` Where string `json:"where"`
Limit *int `json:"limit"` Limit *int `json:"limit"`
Offset *int `json:"offset"` Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
} }
type FilterOption struct { type FilterOption struct {

View File

@@ -6,6 +6,7 @@ import (
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// ColumnValidator validates column names against a model's fields // ColumnValidator validates column names against a model's fields
@@ -92,23 +93,6 @@ func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
return strings.ToLower(field.Name) return strings.ToLower(field.Name)
} }
// extractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func extractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ValidateColumn validates a single column name // ValidateColumn validates a single column name
// Returns nil if valid, error if invalid // Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid // Columns prefixed with "cql" (case insensitive) are always valid
@@ -125,7 +109,7 @@ func (v *ColumnValidator) ValidateColumn(column string) error {
} }
// Extract source column name (remove JSON operators like ->> or ->) // Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := extractSourceColumn(column) sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model // Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists { if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {

View File

@@ -2,6 +2,8 @@ package common
import ( import (
"testing" "testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
func TestExtractSourceColumn(t *testing.T) { func TestExtractSourceColumn(t *testing.T) {
@@ -49,9 +51,9 @@ func TestExtractSourceColumn(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
result := extractSourceColumn(tc.input) result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected { if result != tc.expected {
t.Errorf("extractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected) t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
} }
}) })
} }

291
pkg/config/README.md Normal file
View File

@@ -0,0 +1,291 @@
# ResolveSpec Configuration System
A centralized configuration system with support for multiple configuration sources: config files (YAML, TOML, JSON), environment variables, and programmatic configuration.
## Features
- **Multiple Config Sources**: Config files, environment variables, and code
- **Priority Order**: Environment variables > Config file > Defaults
- **Multiple Formats**: YAML, TOML, JSON supported
- **Type Safety**: Strongly-typed configuration structs
- **Sensible Defaults**: Works out of the box with reasonable defaults
## Quick Start
### Basic Usage
```go
import "github.com/heinhel/ResolveSpec/pkg/config"
// Create a new config manager
mgr := config.NewManager()
// Load configuration from file and environment
if err := mgr.Load(); err != nil {
log.Fatal(err)
}
// Get the complete configuration
cfg, err := mgr.GetConfig()
if err != nil {
log.Fatal(err)
}
// Use the configuration
fmt.Println("Server address:", cfg.Server.Addr)
```
### Custom Configuration Paths
```go
mgr := config.NewManagerWithOptions(
config.WithConfigFile("/path/to/config.yaml"),
config.WithEnvPrefix("MYAPP"),
)
```
## Configuration Sources
### 1. Config Files
Place a `config.yaml` file in one of these locations:
- Current directory (`.`)
- `./config/`
- `/etc/resolvespec/`
- `$HOME/.resolvespec/`
Example `config.yaml`:
```yaml
server:
addr: ":8080"
shutdown_timeout: 30s
tracing:
enabled: true
service_name: "my-service"
cache:
provider: "redis"
redis:
host: "localhost"
port: 6379
```
### 2. Environment Variables
All configuration can be set via environment variables with the `RESOLVESPEC_` prefix:
```bash
export RESOLVESPEC_SERVER_ADDR=":9090"
export RESOLVESPEC_TRACING_ENABLED=true
export RESOLVESPEC_CACHE_PROVIDER=redis
export RESOLVESPEC_CACHE_REDIS_HOST=localhost
```
Nested configuration uses underscores:
- `server.addr``RESOLVESPEC_SERVER_ADDR`
- `cache.redis.host``RESOLVESPEC_CACHE_REDIS_HOST`
### 3. Programmatic Configuration
```go
mgr := config.NewManager()
mgr.Set("server.addr", ":9090")
mgr.Set("tracing.enabled", true)
cfg, _ := mgr.GetConfig()
```
## Configuration Options
### Server Configuration
```yaml
server:
addr: ":8080" # Server address
shutdown_timeout: 30s # Graceful shutdown timeout
drain_timeout: 25s # Connection drain timeout
read_timeout: 10s # HTTP read timeout
write_timeout: 10s # HTTP write timeout
idle_timeout: 120s # HTTP idle timeout
```
### Tracing Configuration
```yaml
tracing:
enabled: false # Enable/disable tracing
service_name: "resolvespec" # Service name
service_version: "1.0.0" # Service version
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
```
### Cache Configuration
```yaml
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
```
### Logger Configuration
```yaml
logger:
dev: false # Development mode (human-readable output)
path: "" # Log file path (empty = stdout)
```
### Middleware Configuration
```yaml
middleware:
rate_limit_rps: 100.0 # Requests per second
rate_limit_burst: 200 # Burst size
max_request_size: 10485760 # Max request size in bytes (10MB)
```
### CORS Configuration
```yaml
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
```
### Database Configuration
```yaml
database:
url: "host=localhost user=postgres password=postgres dbname=mydb port=5432 sslmode=disable"
```
## Priority and Overrides
Configuration sources are applied in this order (highest priority first):
1. **Environment Variables** (highest priority)
2. **Config File**
3. **Defaults** (lowest priority)
This allows you to:
- Set defaults in code
- Override with a config file
- Override specific values with environment variables
## Examples
### Production Setup
```yaml
# config.yaml
server:
addr: ":8080"
tracing:
enabled: true
service_name: "myapi"
endpoint: "http://jaeger:4318/v1/traces"
cache:
provider: "redis"
redis:
host: "redis"
port: 6379
password: "${REDIS_PASSWORD}"
logger:
dev: false
path: "/var/log/myapi/app.log"
```
### Development Setup
```bash
# Use environment variables for development
export RESOLVESPEC_LOGGER_DEV=true
export RESOLVESPEC_TRACING_ENABLED=false
export RESOLVESPEC_CACHE_PROVIDER=memory
```
### Testing Setup
```go
// Override config for tests
mgr := config.NewManager()
mgr.Set("cache.provider", "memory")
mgr.Set("database.url", testDBURL)
cfg, _ := mgr.GetConfig()
```
## Best Practices
1. **Use config files for base configuration** - Define your standard settings
2. **Use environment variables for secrets** - Never commit passwords/tokens
3. **Use environment variables for deployment-specific values** - Different per environment
4. **Keep defaults sensible** - Application should work with minimal configuration
5. **Document your configuration** - Comment your config.yaml files
## Integration with ResolveSpec Components
The configuration system integrates seamlessly with ResolveSpec components:
```go
cfg, _ := config.NewManager().Load().GetConfig()
// Server
srv := server.NewGracefulServer(server.Config{
Addr: cfg.Server.Addr,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
// ... other fields
})
// Tracing
if cfg.Tracing.Enabled {
tracer := tracing.Init(tracing.Config{
ServiceName: cfg.Tracing.ServiceName,
ServiceVersion: cfg.Tracing.ServiceVersion,
Endpoint: cfg.Tracing.Endpoint,
})
defer tracer.Shutdown(context.Background())
}
// Cache
var cacheProvider cache.Provider
switch cfg.Cache.Provider {
case "redis":
cacheProvider = cache.NewRedisProvider(cfg.Cache.Redis.Host, cfg.Cache.Redis.Port, ...)
case "memcache":
cacheProvider = cache.NewMemcacheProvider(cfg.Cache.Memcache.Servers, ...)
default:
cacheProvider = cache.NewMemoryProvider()
}
// Logger
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
```

93
pkg/config/config.go Normal file
View File

@@ -0,0 +1,93 @@
package config
import "time"
// Config represents the complete application configuration
type Config struct {
Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
}
// ServerConfig holds server-related configuration
type ServerConfig struct {
Addr string `mapstructure:"addr"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
}
// TracingConfig holds OpenTelemetry tracing configuration
type TracingConfig struct {
Enabled bool `mapstructure:"enabled"`
ServiceName string `mapstructure:"service_name"`
ServiceVersion string `mapstructure:"service_version"`
Endpoint string `mapstructure:"endpoint"`
}
// CacheConfig holds cache provider configuration
type CacheConfig struct {
Provider string `mapstructure:"provider"` // memory, redis, memcache
Redis RedisConfig `mapstructure:"redis"`
Memcache MemcacheConfig `mapstructure:"memcache"`
}
// RedisConfig holds Redis-specific configuration
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
// MemcacheConfig holds Memcache-specific configuration
type MemcacheConfig struct {
Servers []string `mapstructure:"servers"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
Timeout time.Duration `mapstructure:"timeout"`
}
// LoggerConfig holds logger configuration
type LoggerConfig struct {
Dev bool `mapstructure:"dev"`
Path string `mapstructure:"path"`
}
// MiddlewareConfig holds middleware configuration
type MiddlewareConfig struct {
RateLimitRPS float64 `mapstructure:"rate_limit_rps"`
RateLimitBurst int `mapstructure:"rate_limit_burst"`
MaxRequestSize int64 `mapstructure:"max_request_size"`
}
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowedMethods []string `mapstructure:"allowed_methods"`
AllowedHeaders []string `mapstructure:"allowed_headers"`
MaxAge int `mapstructure:"max_age"`
}
// DatabaseConfig holds database configuration (primarily for testing)
type DatabaseConfig struct {
URL string `mapstructure:"url"`
}
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // sentry, noop
DSN string `mapstructure:"dsn"` // Sentry DSN
Environment string `mapstructure:"environment"` // e.g., production, staging, development
Release string `mapstructure:"release"` // Application version/release
Debug bool `mapstructure:"debug"` // Enable debug mode
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
}

168
pkg/config/manager.go Normal file
View File

@@ -0,0 +1,168 @@
package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
// Manager handles configuration loading from multiple sources
type Manager struct {
v *viper.Viper
}
// NewManager creates a new configuration manager with defaults
func NewManager() *Manager {
v := viper.New()
// Set configuration file settings
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/resolvespec")
v.AddConfigPath("$HOME/.resolvespec")
// Enable environment variable support
v.SetEnvPrefix("RESOLVESPEC")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
// Set default values
setDefaults(v)
return &Manager{v: v}
}
// NewManagerWithOptions creates a new configuration manager with custom options
func NewManagerWithOptions(opts ...Option) *Manager {
m := NewManager()
for _, opt := range opts {
opt(m)
}
return m
}
// Option is a functional option for configuring the Manager
type Option func(*Manager)
// WithConfigFile sets a specific config file path
func WithConfigFile(path string) Option {
return func(m *Manager) {
m.v.SetConfigFile(path)
}
}
// WithConfigName sets the config file name (without extension)
func WithConfigName(name string) Option {
return func(m *Manager) {
m.v.SetConfigName(name)
}
}
// WithConfigPath adds a path to search for config files
func WithConfigPath(path string) Option {
return func(m *Manager) {
m.v.AddConfigPath(path)
}
}
// WithEnvPrefix sets the environment variable prefix
func WithEnvPrefix(prefix string) Option {
return func(m *Manager) {
m.v.SetEnvPrefix(prefix)
}
}
// Load attempts to load configuration from file and environment
func (m *Manager) Load() error {
// Try to read config file (not an error if it doesn't exist)
if err := m.v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return fmt.Errorf("error reading config file: %w", err)
}
// Config file not found; will rely on defaults and env vars
}
return nil
}
// GetConfig returns the complete configuration
func (m *Manager) GetConfig() (*Config, error) {
var cfg Config
if err := m.v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
return &cfg, nil
}
// Get returns a configuration value by key
func (m *Manager) Get(key string) interface{} {
return m.v.Get(key)
}
// GetString returns a string configuration value
func (m *Manager) GetString(key string) string {
return m.v.GetString(key)
}
// GetInt returns an int configuration value
func (m *Manager) GetInt(key string) int {
return m.v.GetInt(key)
}
// GetBool returns a bool configuration value
func (m *Manager) GetBool(key string) bool {
return m.v.GetBool(key)
}
// Set sets a configuration value
func (m *Manager) Set(key string, value interface{}) {
m.v.Set(key, value)
}
// setDefaults sets default configuration values
func setDefaults(v *viper.Viper) {
// Server defaults
v.SetDefault("server.addr", ":8080")
v.SetDefault("server.shutdown_timeout", "30s")
v.SetDefault("server.drain_timeout", "25s")
v.SetDefault("server.read_timeout", "10s")
v.SetDefault("server.write_timeout", "10s")
v.SetDefault("server.idle_timeout", "120s")
// Tracing defaults
v.SetDefault("tracing.enabled", false)
v.SetDefault("tracing.service_name", "resolvespec")
v.SetDefault("tracing.service_version", "1.0.0")
v.SetDefault("tracing.endpoint", "")
// Cache defaults
v.SetDefault("cache.provider", "memory")
v.SetDefault("cache.redis.host", "localhost")
v.SetDefault("cache.redis.port", 6379)
v.SetDefault("cache.redis.password", "")
v.SetDefault("cache.redis.db", 0)
v.SetDefault("cache.memcache.servers", []string{"localhost:11211"})
v.SetDefault("cache.memcache.max_idle_conns", 10)
v.SetDefault("cache.memcache.timeout", "100ms")
// Logger defaults
v.SetDefault("logger.dev", false)
v.SetDefault("logger.path", "")
// Middleware defaults
v.SetDefault("middleware.rate_limit_rps", 100.0)
v.SetDefault("middleware.rate_limit_burst", 200)
v.SetDefault("middleware.max_request_size", 10485760) // 10MB
// CORS defaults
v.SetDefault("cors.allowed_origins", []string{"*"})
v.SetDefault("cors.allowed_methods", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"})
v.SetDefault("cors.allowed_headers", []string{"*"})
v.SetDefault("cors.max_age", 3600)
// Database defaults
v.SetDefault("database.url", "")
}

166
pkg/config/manager_test.go Normal file
View File

@@ -0,0 +1,166 @@
package config
import (
"os"
"testing"
"time"
)
func TestNewManager(t *testing.T) {
mgr := NewManager()
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
if mgr.v == nil {
t.Fatal("Expected viper instance to be non-nil")
}
}
func TestDefaultValues(t *testing.T) {
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test default values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":8080"},
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
{"tracing.enabled", cfg.Tracing.Enabled, false},
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
{"cache.provider", cfg.Cache.Provider, "memory"},
{"cache.redis.host", cfg.Cache.Redis.Host, "localhost"},
{"cache.redis.port", cfg.Cache.Redis.Port, 6379},
{"logger.dev", cfg.Logger.Dev, false},
{"middleware.rate_limit_rps", cfg.Middleware.RateLimitRPS, 100.0},
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestEnvironmentVariableOverrides(t *testing.T) {
// Set environment variables
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
defer func() {
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
}()
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test environment variable overrides
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":9090"},
{"tracing.enabled", cfg.Tracing.Enabled, true},
{"cache.provider", cfg.Cache.Provider, "redis"},
{"logger.dev", cfg.Logger.Dev, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestProgrammaticConfiguration(t *testing.T) {
mgr := NewManager()
mgr.Set("server.addr", ":7070")
mgr.Set("tracing.service_name", "test-service")
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":7070" {
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
}
if cfg.Tracing.ServiceName != "test-service" {
t.Errorf("tracing.service_name: got %s, want test-service", cfg.Tracing.ServiceName)
}
}
func TestGetterMethods(t *testing.T) {
mgr := NewManager()
mgr.Set("test.string", "value")
mgr.Set("test.int", 42)
mgr.Set("test.bool", true)
if got := mgr.GetString("test.string"); got != "value" {
t.Errorf("GetString: got %s, want value", got)
}
if got := mgr.GetInt("test.int"); got != 42 {
t.Errorf("GetInt: got %d, want 42", got)
}
if got := mgr.GetBool("test.bool"); !got {
t.Errorf("GetBool: got %v, want true", got)
}
}
func TestWithOptions(t *testing.T) {
mgr := NewManagerWithOptions(
WithEnvPrefix("MYAPP"),
WithConfigName("myconfig"),
)
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
// Set environment variable with custom prefix
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
defer os.Unsetenv("MYAPP_SERVER_ADDR")
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":5000" {
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
}
}

150
pkg/errortracking/README.md Normal file
View File

@@ -0,0 +1,150 @@
# Error Tracking
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
## Features
- **Provider Interface**: Flexible design supporting multiple error tracking backends
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
- **Panic Tracking**: Automatic panic capture with stack traces
- **NoOp Provider**: Zero-overhead when error tracking is disabled
## Configuration
Add error tracking configuration to your config file:
```yaml
error_tracking:
enabled: true
provider: "sentry" # Currently supports: "sentry" or "noop"
dsn: "https://your-sentry-dsn@sentry.io/project-id"
environment: "production" # e.g., production, staging, development
release: "v1.0.0" # Your application version
debug: false
sample_rate: 1.0 # Error sample rate (0.0-1.0)
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
```
## Usage
### Initialization
Initialize error tracking in your application startup:
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func main() {
// Load your configuration
cfg := config.Config{
ErrorTracking: config.ErrorTrackingConfig{
Enabled: true,
Provider: "sentry",
DSN: "https://your-sentry-dsn@sentry.io/project-id",
Environment: "production",
Release: "v1.0.0",
SampleRate: 1.0,
},
}
// Initialize logger
logger.Init(false)
// Initialize error tracking
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
if err != nil {
logger.Error("Failed to initialize error tracking: %v", err)
} else {
logger.InitErrorTracking(provider)
}
// Your application code...
// Cleanup on shutdown
defer logger.CloseErrorTracking()
}
```
### Automatic Tracking
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
```go
// This will be logged AND sent to Sentry
logger.Error("Database connection failed: %v", err)
// This will also be logged AND sent to Sentry
logger.Warn("Cache miss for key: %s", key)
```
### Panic Tracking
Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
// Using HandlePanic
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("MyMethod", r)
}
}()
```
### Manual Tracking
You can also use the provider directly for custom error tracking:
```go
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func someFunction() {
tracker := logger.GetErrorTracker()
if tracker != nil {
// Capture an error
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
"user_id": userID,
"request_id": requestID,
})
// Capture a message
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
"event_type": "user_signup",
})
// Capture a panic
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
"context": "background_job",
})
}
}
```
## Severity Levels
The package supports the following severity levels:
- `SeverityError`: For errors that should be tracked and investigated
- `SeverityWarning`: For warnings that may indicate potential issues
- `SeverityInfo`: For informational messages
- `SeverityDebug`: For debug-level information
```

View File

@@ -0,0 +1,67 @@
package errortracking
import (
"context"
"errors"
"testing"
)
func TestNoOpProvider(t *testing.T) {
provider := NewNoOpProvider()
// Test that all methods can be called without panicking
t.Run("CaptureError", func(t *testing.T) {
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
})
t.Run("CaptureMessage", func(t *testing.T) {
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
})
t.Run("CapturePanic", func(t *testing.T) {
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
})
t.Run("Flush", func(t *testing.T) {
result := provider.Flush(5)
if !result {
t.Error("Expected Flush to return true")
}
})
t.Run("Close", func(t *testing.T) {
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to return nil, got %v", err)
}
})
}
func TestSeverityLevels(t *testing.T) {
tests := []struct {
name string
severity Severity
expected string
}{
{"Error", SeverityError, "error"},
{"Warning", SeverityWarning, "warning"},
{"Info", SeverityInfo, "info"},
{"Debug", SeverityDebug, "debug"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.severity) != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
}
})
}
}
func TestProviderInterface(t *testing.T) {
// Test that NoOpProvider implements Provider interface
var _ Provider = (*NoOpProvider)(nil)
// Test that SentryProvider implements Provider interface
var _ Provider = (*SentryProvider)(nil)
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates an error tracking provider based on the configuration
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
if !cfg.Enabled {
return NewNoOpProvider(), nil
}
switch cfg.Provider {
case "sentry":
if cfg.DSN == "" {
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
}
return NewSentryProvider(SentryConfig{
DSN: cfg.DSN,
Environment: cfg.Environment,
Release: cfg.Release,
Debug: cfg.Debug,
SampleRate: cfg.SampleRate,
TracesSampleRate: cfg.TracesSampleRate,
})
case "noop", "":
return NewNoOpProvider(), nil
default:
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
}
}

View File

@@ -0,0 +1,33 @@
package errortracking
import (
"context"
)
// Severity represents the severity level of an error
type Severity string
const (
SeverityError Severity = "error"
SeverityWarning Severity = "warning"
SeverityInfo Severity = "info"
SeverityDebug Severity = "debug"
)
// Provider defines the interface for error tracking providers
type Provider interface {
// CaptureError captures an error with the given severity and additional context
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
// CaptureMessage captures a message with the given severity and additional context
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
// CapturePanic captures a panic with stack trace
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
// Flush waits for all events to be sent (useful for graceful shutdown)
Flush(timeout int) bool
// Close closes the provider and releases resources
Close() error
}

37
pkg/errortracking/noop.go Normal file
View File

@@ -0,0 +1,37 @@
package errortracking
import "context"
// NoOpProvider is a no-op implementation of the Provider interface
// Used when error tracking is disabled
type NoOpProvider struct{}
// NewNoOpProvider creates a new NoOp provider
func NewNoOpProvider() *NoOpProvider {
return &NoOpProvider{}
}
// CaptureError does nothing
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
// No-op
}
// CaptureMessage does nothing
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
// No-op
}
// CapturePanic does nothing
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
// No-op
}
// Flush does nothing and returns true
func (n *NoOpProvider) Flush(timeout int) bool {
return true
}
// Close does nothing
func (n *NoOpProvider) Close() error {
return nil
}

154
pkg/errortracking/sentry.go Normal file
View File

@@ -0,0 +1,154 @@
package errortracking
import (
"context"
"fmt"
"time"
"github.com/getsentry/sentry-go"
)
// SentryProvider implements the Provider interface using Sentry
type SentryProvider struct {
hub *sentry.Hub
}
// SentryConfig holds the configuration for Sentry
type SentryConfig struct {
DSN string
Environment string
Release string
Debug bool
SampleRate float64
TracesSampleRate float64
}
// NewSentryProvider creates a new Sentry provider
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
err := sentry.Init(sentry.ClientOptions{
Dsn: config.DSN,
Environment: config.Environment,
Release: config.Release,
Debug: config.Debug,
AttachStacktrace: true,
SampleRate: config.SampleRate,
TracesSampleRate: config.TracesSampleRate,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
}
return &SentryProvider{
hub: sentry.CurrentHub(),
}, nil
}
// CaptureError captures an error with the given severity and additional context
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
if err == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = err.Error()
event.Exception = []sentry.Exception{
{
Value: err.Error(),
Type: fmt.Sprintf("%T", err),
Stacktrace: sentry.ExtractStacktrace(err),
},
}
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CaptureMessage captures a message with the given severity and additional context
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
if message == "" {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = message
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CapturePanic captures a panic with stack trace
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
if recovered == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = sentry.LevelError
event.Message = fmt.Sprintf("Panic: %v", recovered)
event.Exception = []sentry.Exception{
{
Value: fmt.Sprintf("%v", recovered),
Type: "panic",
},
}
if extra != nil {
event.Extra = extra
}
if stackTrace != nil {
event.Extra["stack_trace"] = string(stackTrace)
}
hub.CaptureEvent(event)
}
// Flush waits for all events to be sent (useful for graceful shutdown)
func (s *SentryProvider) Flush(timeout int) bool {
return sentry.Flush(time.Duration(timeout) * time.Second)
}
// Close closes the provider and releases resources
func (s *SentryProvider) Close() error {
sentry.Flush(2 * time.Second)
return nil
}
// convertSeverity converts our Severity to Sentry's Level
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
switch severity {
case SeverityError:
return sentry.LevelError
case SeverityWarning:
return sentry.LevelWarning
case SeverityInfo:
return sentry.LevelInfo
case SeverityDebug:
return sentry.LevelDebug
default:
return sentry.LevelError
}
}

1021
pkg/funcspec/function_api.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,906 @@
package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
return nil
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
return nil
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
return nil
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
return nil
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
if m.ExecFunc != nil {
return m.ExecFunc(ctx, query, args...)
}
return &MockResult{rows: 0}, nil
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if m.QueryFunc != nil {
return m.QueryFunc(ctx, dest, query, args...)
}
return nil
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
return m, nil
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
if m.RunInTransactionFunc != nil {
return m.RunInTransactionFunc(ctx, fn)
}
return fn(m)
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64
id int64
}
func (m *MockResult) RowsAffected() int64 {
return m.rows
}
func (m *MockResult) LastInsertId() (int64, error) {
return m.id, nil
}
// Helper function to create a test request with user context
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
u, _ := url.Parse(path)
if queryParams != nil {
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
var bodyReader *bytes.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
} else {
bodyReader = bytes.NewReader([]byte{})
}
req := httptest.NewRequest(method, u.String(), bodyReader)
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
// Add user context
userCtx := &security.UserContext{
UserID: 1,
UserName: "testuser",
SessionID: "test-session-123",
}
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
req = req.WithContext(ctx)
return req
}
// TestNewHandler tests handler creation
func TestNewHandler(t *testing.T) {
db := &MockDatabase{}
handler := NewHandler(db)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.db != db {
t.Error("Expected handler to have the provided database")
}
if handler.hooks == nil {
t.Error("Expected handler to have a hook registry")
}
}
// TestHandlerHooks tests the Hooks method
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(&MockDatabase{})
hooks := handler.Hooks()
if hooks == nil {
t.Fatal("Expected hooks registry to be non-nil")
}
// Should return the same instance
hooks2 := handler.Hooks()
if hooks != hooks2 {
t.Error("Expected Hooks() to return the same registry instance")
}
}
// TestExtractInputVariables tests the extractInputVariables function
func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
sqlQuery: "SELECT * FROM users",
expectedVars: []string{},
},
{
name: "Single variable",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
expectedVars: []string{"[user_id]"},
},
{
name: "Multiple variables",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
expectedVars: []string{"[user_id]", "[user_name]"},
},
{
name: "Nested brackets",
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
expectedVars: []string{"[field]", "[user_id]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputvars := make([]string, 0)
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
if result != tt.sqlQuery {
t.Errorf("Expected SQL query to be unchanged, got %s", result)
}
if len(inputvars) != len(tt.expectedVars) {
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
return
}
for i, expected := range tt.expectedVars {
if inputvars[i] != expected {
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
}
}
})
}
}
// TestValidSQL tests the SQL sanitization function
func TestValidSQL(t *testing.T) {
tests := []struct {
name string
input string
mode string
expected string
}{
{
name: "Column name with valid characters",
input: "user_id",
mode: "colname",
expected: "user_id",
},
{
name: "Column name with dots (table.column)",
input: "users.user_id",
mode: "colname",
expected: "users.user_id",
},
{
name: "Column name with SQL injection attempt",
input: "id'; DROP TABLE users--",
mode: "colname",
expected: "idDROPTABLEusers",
},
{
name: "Column value with single quotes",
input: "O'Brien",
mode: "colvalue",
expected: "O''Brien",
},
{
name: "Select with dangerous keywords",
input: "name, email; DROP TABLE users",
mode: "select",
expected: "name, email TABLE users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ValidSQL(tt.input, tt.mode)
if result != tt.expected {
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
}
})
}
}
// TestIsNumeric tests the IsNumeric function
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"123", true},
{"123.45", true},
{"-123", true},
{"-123.45", true},
{"0", true},
{"abc", false},
{"12.34.56", false},
{"", false},
{"123abc", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := IsNumeric(tt.input)
if result != tt.expected {
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
// TestSqlQryWhere tests the WHERE clause manipulation
func TestSqlQryWhere(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
expected string
}{
{
name: "Add WHERE to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ",
},
{
name: "Add AND to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
},
{
name: "Add WHERE before ORDER BY",
sqlQuery: "SELECT * FROM users ORDER BY name",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
},
{
name: "Add WHERE before GROUP BY",
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
condition: "status = 'active'",
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
},
{
name: "Add WHERE before LIMIT",
sqlQuery: "SELECT * FROM users LIMIT 10",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhere(tt.sqlQuery, tt.condition)
if result != tt.expected {
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
return req
},
expected: "192.168.1.100",
},
{
name: "X-Real-IP header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
return req
},
expected: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
return req
},
expected: "192.168.1.1:12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupReq()
result := getIPAddress(req)
if result != tt.expected {
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestParsePaginationParams tests pagination parameter parsing
func TestParsePaginationParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
expectedSort string
expectedLimit int
expectedOffset int
}{
{
name: "No parameters - defaults",
queryParams: map[string]string{},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "All parameters provided",
queryParams: map[string]string{
"sort": "name,-created_at",
"limit": "100",
"offset": "50",
},
expectedSort: "name,-created_at",
expectedLimit: 100,
expectedOffset: 50,
},
{
name: "Invalid limit - use default",
queryParams: map[string]string{
"limit": "invalid",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "Negative offset - use default",
queryParams: map[string]string{
"offset": "-10",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
sort, limit, offset := handler.parsePaginationParams(req)
if sort != tt.expectedSort {
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
}
if limit != tt.expectedLimit {
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
}
if offset != tt.expectedOffset {
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
}
})
}
}
// TestSqlQuery tests the SqlQuery handler for single record queries
func TestSqlQuery(t *testing.T) {
tests := []struct {
name string
sqlQuery string
blankParams bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, body []byte)
}{
{
name: "Basic query - returns single record",
sqlQuery: "SELECT * FROM users WHERE id = 1",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if result["name"] != "Test User" {
t.Errorf("Expected name='Test User', got %v", result["name"])
}
},
},
{
name: "Query with no results",
sqlQuery: "SELECT * FROM users WHERE id = 999",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
// Return empty array
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 0 {
t.Errorf("Expected empty result, got %v", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
if tt.validateResp != nil {
tt.validateResp(t, w.Body.Bytes())
}
})
}
}
// TestSqlQueryList tests the SqlQueryList handler for list queries
func TestSqlQueryList(t *testing.T) {
tests := []struct {
name string
sqlQuery string
noCount bool
blankParams bool
allowFilter bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
}{
{
name: "Basic list query",
sqlQuery: "SELECT * FROM users",
noCount: false,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
callCount := 0
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callCount++
if strings.Contains(query, "COUNT") {
// Count query
countResult := dest.(*struct{ Count int64 })
countResult.Count = 2
} else {
// Main query
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
{"id": float64(2), "name": "User 2"},
}
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 2 {
t.Errorf("Expected 2 results, got %d", len(result))
}
// Check Content-Range header
contentRange := w.Header().Get("Content-Range")
if !strings.Contains(contentRange, "2") {
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
}
},
},
{
name: "List query with noCount",
sqlQuery: "SELECT * FROM users",
noCount: true,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
t.Error("Count query should not be executed when noCount is true")
}
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 1 {
t.Errorf("Expected 1 result, got %d", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
}
if tt.validateResp != nil {
tt.validateResp(t, w)
}
})
}
}
// TestMergeQueryParams tests query parameter merging
func TestMergeQueryParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
queryParams map[string]string
allowFilter bool
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Replace placeholder with parameter",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
queryParams: map[string]string{"p-user_id": "123"},
allowFilter: false,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["p-user_id"] != "123" {
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
}
},
},
{
name: "Add filter when allowed",
sqlQuery: "SELECT * FROM users",
queryParams: map[string]string{"status": "active"},
allowFilter: true,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["status"] != "active" {
t.Errorf("Expected status=active, got %v", vars["status"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestMergeHeaderParams tests header parameter merging
func TestMergeHeaderParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
headers map[string]string
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Field filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-FieldFilter-Status": "1"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-fieldfilter-status"] != "1" {
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
}
},
},
{
name: "Search filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-SearchFilter-Name": "john"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-searchfilter-name"] != "john" {
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
complexAPI := false
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestReplaceMetaVariables tests meta variable replacement
func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "ABC456",
SessionRID: 456,
}
metainfo := map[string]interface{}{
"ipaddress": "192.168.1.1",
"url": "/api/test",
}
variables := map[string]interface{}{
"param1": "value1",
}
tests := []struct {
name string
sqlQuery string
expectedCheck func(result string) bool
}{
{
name: "Replace [rid_user]",
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "123")
},
},
{
name: "Replace [user]",
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'testuser'")
},
},
{
name: "Replace [rid_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "456")
},
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, nil, nil)
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
if !tt.expectedCheck(result) {
t.Errorf("Meta variable replacement failed. Query: %s", result)
}
})
}
}
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
func TestGetReplacementForBlankParam(t *testing.T) {
tests := []struct {
name string
sqlQuery string
param string
expected string
}{
{
name: "Parameter in single quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
param: "[username]",
expected: "",
},
{
name: "Parameter in dollar quotes",
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
param: "[jsondata]",
expected: "",
},
{
name: "Parameter not in quotes",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter not in quotes with AND",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - before quote",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - in quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
param: "[username]",
expected: "",
},
{
name: "Parameter with dollar quote tag",
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
param: "[content]",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
if result != tt.expected {
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
}
})
}
}

160
pkg/funcspec/hooks.go Normal file
View File

@@ -0,0 +1,160 @@
package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Query operation hooks (for SqlQuery - single record)
BeforeQuery HookType = "before_query"
AfterQuery HookType = "after_query"
// Query list operation hooks (for SqlQueryList - multiple records)
BeforeQueryList HookType = "before_query_list"
AfterQueryList HookType = "after_query_list"
// SQL execution hooks (just before SQL is executed)
BeforeSQLExec HookType = "before_sql_exec"
AfterSQLExec HookType = "after_sql_exec"
// Response hooks (before response is sent)
BeforeResponse HookType = "before_response"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database
Request *http.Request
Writer http.ResponseWriter
// SQL query and variables
SQLQuery string // The SQL query being executed (can be modified by hooks)
Variables map[string]interface{} // Variables extracted from request
InputVars []string // Input variable placeholders found in query
MetaInfo map[string]interface{} // Metadata about the request
PropQry map[string]string // Property query parameters
// User context
UserContext *security.UserContext
// Pagination and filtering (for list queries)
SortColumns string
Limit int
Offset int
// Results
Result interface{} // Query result (single record or list)
Total int64 // Total count (for list queries)
Error error // Error if operation failed
ComplexAPI bool // Whether complex API response format is requested
NoCount bool // Whether count query should be skipped
BlankParams bool // Whether blank parameters should be removed
AllowFilter bool // Whether filtering is allowed
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all funcspec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all funcspec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@@ -0,0 +1,137 @@
package funcspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example hook functions demonstrating various use cases
// ExampleLoggingHook logs all SQL queries before execution
func ExampleLoggingHook(ctx *HookContext) error {
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
return nil
}
// ExampleSecurityHook validates user permissions before executing queries
func ExampleSecurityHook(ctx *HookContext) error {
// Example: Block queries that try to access sensitive tables
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
if ctx.UserContext.UserID != 1 { // Only admin can access
ctx.Abort = true
ctx.AbortCode = 403
ctx.AbortMessage = "Access denied: insufficient permissions"
return fmt.Errorf("access denied to sensitive_table")
}
}
return nil
}
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
func ExampleQueryModificationHook(ctx *HookContext) error {
// Example: Automatically add user_id filter for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
// Add WHERE clause to filter by user_id
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
} else {
ctx.SQLQuery = strings.Replace(
ctx.SQLQuery,
"WHERE",
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
1,
)
}
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
}
return nil
}
// ExampleResultFilterHook filters results after query execution
func ExampleResultFilterHook(ctx *HookContext) error {
// Example: Remove sensitive fields from results for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
switch result := ctx.Result.(type) {
case []map[string]interface{}:
// Filter list results
for i := range result {
delete(result[i], "password")
delete(result[i], "ssn")
delete(result[i], "credit_card")
}
case map[string]interface{}:
// Filter single record
delete(result, "password")
delete(result, "ssn")
delete(result, "credit_card")
}
}
return nil
}
// ExampleAuditHook logs all queries and results for audit purposes
func ExampleAuditHook(ctx *HookContext) error {
// Log to audit table or external system
logger.Info("AUDIT: User %s (%d) executed query from %s",
ctx.UserContext.UserName,
ctx.UserContext.UserID,
ctx.Request.RemoteAddr,
)
// In a real implementation, you might:
// - Insert into an audit log table
// - Send to a logging service
// - Write to a file
return nil
}
// ExampleCacheHook implements simple response caching
func ExampleCacheHook(ctx *HookContext) error {
// This is a simplified example - real caching would use a proper cache store
// Check if we have a cached result for this query
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
// ctx.Result = cachedResult
// ctx.Abort = true // Skip query execution
// ctx.AbortMessage = "Serving from cache"
// }
return nil
}
// ExampleErrorHandlingHook provides custom error handling
func ExampleErrorHandlingHook(ctx *HookContext) error {
if ctx.Error != nil {
// Log error with context
logger.Error("Query failed for user %s: %v\nQuery: %s",
ctx.UserContext.UserName,
ctx.Error,
ctx.SQLQuery,
)
// You could send notifications, update metrics, etc.
}
return nil
}
// Example of registering hooks:
//
// func SetupHooks(handler *Handler) {
// hooks := handler.Hooks()
//
// // Register security hook before query execution
// hooks.Register(BeforeQuery, ExampleSecurityHook)
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
//
// // Register logging hook before SQL execution
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
//
// // Register result filtering after query
// hooks.Register(AfterQuery, ExampleResultFilterHook)
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
//
// // Register audit hook after execution
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
// }

589
pkg/funcspec/hooks_test.go Normal file
View File

@@ -0,0 +1,589 @@
package funcspec
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// TestNewHookRegistry tests hook registry creation
func TestNewHookRegistry(t *testing.T) {
registry := NewHookRegistry()
if registry == nil {
t.Fatal("Expected registry to be created, got nil")
}
if registry.hooks == nil {
t.Error("Expected hooks map to be initialized")
}
}
// TestRegisterHook tests registering a single hook
func TestRegisterHook(t *testing.T) {
registry := NewHookRegistry()
hookCalled := false
testHook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
registry.Register(BeforeQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected hook to be registered")
}
if registry.Count(BeforeQuery) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
}
// Execute the hook
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !hookCalled {
t.Error("Expected hook to be called")
}
}
// TestRegisterMultipleHooks tests registering multiple hooks for same type
func TestRegisterMultipleHooks(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
if registry.Count(BeforeQuery) != 3 {
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
}
// Execute hooks
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
// Verify hooks were called in order
if len(callOrder) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
}
for i, expected := range []int{1, 2, 3} {
if callOrder[i] != expected {
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
}
}
}
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
func TestRegisterMultipleHookTypes(t *testing.T) {
registry := NewHookRegistry()
callCount := 0
testHook := func(ctx *HookContext) error {
callCount++
return nil
}
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
registry.RegisterMultiple(hookTypes, testHook)
// Verify hook is registered for all types
for _, hookType := range hookTypes {
if !registry.HasHooks(hookType) {
t.Errorf("Expected hook to be registered for %s", hookType)
}
if registry.Count(hookType) != 1 {
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
}
}
// Execute each hook type
ctx := &HookContext{}
for _, hookType := range hookTypes {
if err := registry.Execute(hookType, ctx); err != nil {
t.Errorf("Hook execution failed for %s: %v", hookType, err)
}
}
if callCount != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
expectedError := fmt.Errorf("test error")
errorHook := func(ctx *HookContext) error {
return expectedError
}
registry.Register(BeforeQuery, errorHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook, got nil")
}
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
t.Errorf("Expected error message to contain hook error, got: %v", err)
}
}
// TestHookAbort tests hook abort functionality
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeQuery, abortHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error when hook aborts, got nil")
}
if !ctx.Abort {
t.Error("Expected Abort to be true")
}
if ctx.AbortMessage != "Operation aborted by hook" {
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
}
}
// TestHookChainWithError tests that hook chain stops on first error
func TestHookChainWithError(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return fmt.Errorf("error in hook 2")
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook chain")
}
// Only first two hooks should have been called
if len(callOrder) != 2 {
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
}
if callOrder[0] != 1 || callOrder[1] != 2 {
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hook to be registered")
}
registry.Clear(BeforeQuery)
if registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hooks to be cleared")
}
if !registry.HasHooks(AfterQuery) {
t.Error("Expected AfterQuery hook to still be registered")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
registry.Register(BeforeSQLExec, testHook)
registry.ClearAll()
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
t.Error("Expected all hooks to be cleared")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
types := registry.GetAllHookTypes()
if len(types) != 2 {
t.Errorf("Expected 2 hook types, got %d", len(types))
}
// Verify the types are present
foundBefore := false
foundAfter := false
for _, hookType := range types {
if hookType == BeforeQuery {
foundBefore = true
}
if hookType == AfterQuery {
foundAfter = true
}
}
if !foundBefore || !foundAfter {
t.Error("Expected both BeforeQuery and AfterQuery hook types")
}
}
// TestHookContextModification tests that hooks can modify the context
func TestHookContextModification(t *testing.T) {
registry := NewHookRegistry()
// Hook that modifies SQL query
modifyHook := func(ctx *HookContext) error {
ctx.SQLQuery = "SELECT * FROM modified_table"
ctx.Variables["new_var"] = "new_value"
return nil
}
registry.Register(BeforeQuery, modifyHook)
ctx := &HookContext{
SQLQuery: "SELECT * FROM original_table",
Variables: make(map[string]interface{}),
}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if ctx.SQLQuery != "SELECT * FROM modified_table" {
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
}
if ctx.Variables["new_var"] != "new_value" {
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
}
}
// TestExampleHooks tests the example hooks
func TestExampleLoggingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleLoggingHook(ctx)
if err != nil {
t.Errorf("ExampleLoggingHook failed: %v", err)
}
}
func TestExampleSecurityHook(t *testing.T) {
tests := []struct {
name string
sqlQuery string
userID int
shouldAbort bool
}{
{
name: "Admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 1,
shouldAbort: false,
},
{
name: "Non-admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 2,
shouldAbort: true,
},
{
name: "Non-admin accessing normal table",
sqlQuery: "SELECT * FROM users",
userID: 2,
shouldAbort: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: tt.sqlQuery,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
_ = ExampleSecurityHook(ctx)
if tt.shouldAbort {
if !ctx.Abort {
t.Error("Expected security hook to abort operation")
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
}
} else {
if ctx.Abort {
t.Error("Expected security hook not to abort operation")
}
}
})
}
}
func TestExampleResultFilterHook(t *testing.T) {
tests := []struct {
name string
userID int
result interface{}
validate func(t *testing.T, result interface{})
}{
{
name: "Admin user - no filtering",
userID: 1,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; !exists {
t.Error("Expected password field to remain for admin")
}
},
},
{
name: "Regular user - sensitive fields removed",
userID: 2,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
"ssn": "123-45-6789",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed")
}
if _, exists := m["ssn"]; exists {
t.Error("Expected ssn field to be removed")
}
if _, exists := m["name"]; !exists {
t.Error("Expected name field to remain")
}
},
},
{
name: "Regular user - list results filtered",
userID: 2,
result: []map[string]interface{}{
{"id": 1, "name": "User 1", "password": "secret1"},
{"id": 2, "name": "User 2", "password": "secret2"},
},
validate: func(t *testing.T, result interface{}) {
list := result.([]map[string]interface{})
for _, m := range list {
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed from list")
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Result: tt.result,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
err := ExampleResultFilterHook(ctx)
if err != nil {
t.Errorf("Hook failed: %v", err)
}
if tt.validate != nil {
tt.validate(t, ctx.Result)
}
})
}
}
func TestExampleAuditHook(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
ctx := &HookContext{
Context: context.Background(),
Request: req,
UserContext: &security.UserContext{
UserID: 123,
UserName: "testuser",
},
}
err := ExampleAuditHook(ctx)
if err != nil {
t.Errorf("ExampleAuditHook failed: %v", err)
}
}
func TestExampleErrorHandlingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
Error: fmt.Errorf("test error"),
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleErrorHandlingHook(ctx)
if err != nil {
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
}
}
// TestHookIntegrationWithHandler tests hooks integrated with the handler
func TestHookIntegrationWithHandler(t *testing.T) {
db := &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
queryDB := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User"},
}
return nil
},
}
return fn(queryDB)
},
}
handler := NewHandler(db)
// Register a hook that modifies the SQL query
hookCalled := false
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
hookCalled = true
// Verify we can access context data
if ctx.SQLQuery == "" {
t.Error("Expected SQL query to be set")
}
if ctx.UserContext == nil {
t.Error("Expected user context to be set")
}
return nil
})
// Execute a query
req := createTestRequest("GET", "/test", nil, nil, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
handlerFunc(w, req)
if !hookCalled {
t.Error("Expected hook to be called during query execution")
}
if w.Code != 200 {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

411
pkg/funcspec/parameters.go Normal file
View File

@@ -0,0 +1,411 @@
package funcspec
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RequestParameters holds parsed parameters from headers and query string
type RequestParameters struct {
// Field selection
SelectFields []string
NotSelectFields []string
Distinct bool
// Filtering
FieldFilters map[string]string // column -> value (exact match)
SearchFilters map[string]string // column -> value (ILIKE)
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
CustomSQLWhere string
CustomSQLOr string
// Sorting & Pagination
SortColumns string
Limit int
Offset int
// Advanced features
SkipCount bool
SkipCache bool
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
ComplexAPI bool // true if NOT simple API
}
// FilterOperator represents a filter with operator
type FilterOperator struct {
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
Value string
Logic string // AND or OR
}
// ParseParameters parses all parameters from request headers and query string
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
params := &RequestParameters{
FieldFilters: make(map[string]string),
SearchFilters: make(map[string]string),
SearchOps: make(map[string]FilterOperator),
Limit: 20, // Default limit
Offset: 0, // Default offset
ResponseFormat: "simple", // Default format
ComplexAPI: false, // Default to simple API
}
// Merge headers and query parameters
combined := make(map[string]string)
// Add all headers (normalize to lowercase)
for key, values := range r.Header {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Add all query parameters (override headers)
for key, values := range r.URL.Query() {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Parse each parameter
for key, value := range combined {
// Decode value if base64 encoded
decodedValue := h.decodeValue(value)
switch {
// Field Selection
case strings.HasPrefix(key, "x-select-fields"):
params.SelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-not-select-fields"):
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-distinct"):
params.Distinct = strings.EqualFold(decodedValue, "true")
// Filtering
case strings.HasPrefix(key, "x-fieldfilter-"):
colName := strings.TrimPrefix(key, "x-fieldfilter-")
params.FieldFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchfilter-"):
colName := strings.TrimPrefix(key, "x-searchfilter-")
params.SearchFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(params, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-custom-sql-w"):
if params.CustomSQLWhere != "" {
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
} else {
params.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if params.CustomSQLOr != "" {
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
} else {
params.CustomSQLOr = decodedValue
}
// Sorting & Pagination
case key == "sort" || strings.HasPrefix(key, "x-sort"):
params.SortColumns = decodedValue
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
// Handle sort(col1,-col2) syntax
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
params.SortColumns = sortValue
case key == "limit" || strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
params.Limit = limit
}
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
// Handle limit(offset,limit) or limit(limit) syntax
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
parts := strings.Split(limitValue, ",")
if len(parts) > 1 {
if offset, err := strconv.Atoi(parts[0]); err == nil {
params.Offset = offset
}
if limit, err := strconv.Atoi(parts[1]); err == nil {
params.Limit = limit
}
} else {
if limit, err := strconv.Atoi(parts[0]); err == nil {
params.Limit = limit
}
}
case key == "offset" || strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
params.Offset = offset
}
// Advanced features
case strings.HasPrefix(key, "x-skipcount"):
params.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcache"):
params.SkipCache = strings.EqualFold(decodedValue, "true")
// Response Format
case strings.HasPrefix(key, "x-simpleapi"):
params.ResponseFormat = "simple"
params.ComplexAPI = decodedValue != "1" && !strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-detailapi"):
params.ResponseFormat = "detail"
params.ComplexAPI = true
case strings.HasPrefix(key, "x-syncfusion"):
params.ResponseFormat = "syncfusion"
params.ComplexAPI = true
}
}
return params
}
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
var prefix string
if logic == "OR" {
prefix = "x-searchor-"
} else {
prefix = "x-searchop-"
if strings.HasPrefix(headerKey, "x-searchand-") {
prefix = "x-searchand-"
}
}
rest := strings.TrimPrefix(headerKey, prefix)
parts := strings.SplitN(rest, "-", 2)
if len(parts) != 2 {
logger.Warn("Invalid search operator header format: %s", headerKey)
return
}
operator := parts[0]
colName := parts[1]
params.SearchOps[colName] = FilterOperator{
Operator: operator,
Value: value,
Logic: logic,
}
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
}
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
func (h *Handler) decodeValue(value string) string {
decoded, _ := restheadspec.DecodeParam(value)
return decoded
}
// parseCommaSeparated parses comma-separated values
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
return result
}
// ApplyFieldSelection applies column selection to SQL query
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
return sqlQuery
}
// This is a simplified implementation
// A full implementation would parse the SQL and replace the SELECT clause
// For now, we log a warning that this feature needs manual implementation
if len(params.SelectFields) > 0 {
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
}
if len(params.NotSelectFields) > 0 {
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
}
return sqlQuery
}
// ApplyFilters applies all filters to the SQL query
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
// Apply field filters (exact match)
for colName, value := range params.FieldFilters {
condition := ""
if value == "" || value == "0" {
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
} else {
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
}
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied field filter: %s", condition)
}
// Apply search filters (ILIKE)
for colName, value := range params.SearchFilters {
sval := strings.ReplaceAll(value, "'", "")
if sval != "" {
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied search filter: %s", condition)
}
}
// Apply search operators
for colName, filterOp := range params.SearchOps {
condition := h.buildFilterCondition(colName, filterOp)
if condition != "" {
if filterOp.Logic == "OR" {
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
} else {
sqlQuery = sqlQryWhere(sqlQuery, condition)
}
logger.Debug("Applied search operator: %s", condition)
}
}
// Apply custom SQL WHERE
if params.CustomSQLWhere != "" {
colval := ValidSQL(params.CustomSQLWhere, "select")
if colval != "" {
sqlQuery = sqlQryWhere(sqlQuery, colval)
logger.Debug("Applied custom SQL WHERE: %s", colval)
}
}
// Apply custom SQL OR
if params.CustomSQLOr != "" {
colval := ValidSQL(params.CustomSQLOr, "select")
if colval != "" {
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
logger.Debug("Applied custom SQL OR: %s", colval)
}
}
return sqlQuery
}
// buildFilterCondition builds a SQL condition from a FilterOperator
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
safCol := ValidSQL(colName, "colname")
operator := strings.ToLower(op.Operator)
value := op.Value
switch operator {
case "contains", "contain", "like":
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
case "beginswith", "startswith":
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
case "endswith":
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
case "equals", "eq", "=":
if IsNumeric(value) {
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
case "notequals", "neq", "ne", "!=", "<>":
if IsNumeric(value) {
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
case "greaterthan", "gt", ">":
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
case "lessthan", "lt", "<":
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
case "greaterthanorequal", "gte", "ge", ">=":
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
case "lessthanorequal", "lte", "le", "<=":
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
case "between":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "betweeninclusive":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "in":
values := strings.Split(value, ",")
safeValues := make([]string, len(values))
for i, v := range values {
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
}
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
case "empty", "isnull", "null":
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
case "notempty", "isnotnull", "notnull":
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
}
return ""
}
// ApplyDistinct adds DISTINCT to SQL query if requested
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
if !params.Distinct {
return sqlQuery
}
// Add DISTINCT after SELECT
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
if selectPos >= 0 {
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
afterSelect := sqlQuery[selectPos+6:]
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
logger.Debug("Applied DISTINCT to query")
}
return sqlQuery
}
// sqlQryWhereOr adds a WHERE clause with OR logic
func sqlQryWhereOr(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add OR condition
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
} else {
// No WHERE exists, add it
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}

View File

@@ -0,0 +1,549 @@
package funcspec
import (
"strings"
"testing"
)
// TestParseParameters tests the comprehensive parameter parsing
func TestParseParameters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, params *RequestParameters)
}{
{
name: "Parse field selection",
headers: map[string]string{
"X-Select-Fields": "id,name,email",
"X-Not-Select-Fields": "password,ssn",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SelectFields) != 3 {
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
}
if len(params.NotSelectFields) != 2 {
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
}
},
},
{
name: "Parse distinct flag",
headers: map[string]string{
"X-Distinct": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse field filters",
headers: map[string]string{
"X-FieldFilter-Status": "active",
"X-FieldFilter-Type": "admin",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.FieldFilters) != 2 {
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
}
if params.FieldFilters["status"] != "active" {
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
}
},
},
{
name: "Parse search filters",
headers: map[string]string{
"X-SearchFilter-Name": "john",
"X-SearchFilter-Email": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchFilters) != 2 {
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
}
},
},
{
name: "Parse sort columns",
queryParams: map[string]string{
"sort": "-created_at,name",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.SortColumns != "-created_at,name" {
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
}
},
},
{
name: "Parse limit and offset",
queryParams: map[string]string{
"limit": "100",
"offset": "50",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.Limit != 100 {
t.Errorf("Expected limit=100, got %d", params.Limit)
}
if params.Offset != 50 {
t.Errorf("Expected offset=50, got %d", params.Offset)
}
},
},
{
name: "Parse skip count",
headers: map[string]string{
"X-SkipCount": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format - syncfusion",
headers: map[string]string{
"X-Syncfusion": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
}
if !params.ComplexAPI {
t.Error("Expected ComplexAPI to be true for syncfusion format")
}
},
},
{
name: "Parse response format - detail",
headers: map[string]string{
"X-DetailAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "detail" {
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
}
},
},
{
name: "Parse simple API",
headers: map[string]string{
"X-SimpleAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "simple" {
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
}
if params.ComplexAPI {
t.Error("Expected ComplexAPI to be false for simple API")
}
},
},
{
name: "Parse custom SQL WHERE",
headers: map[string]string{
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
},
},
{
name: "Parse search operators - AND",
headers: map[string]string{
"X-SearchOp-Eq-Name": "john",
"X-SearchOp-Gt-Age": "18",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchOps) != 2 {
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
}
if op, exists := params.SearchOps["name"]; exists {
if op.Operator != "eq" {
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
}
if op.Logic != "AND" {
t.Errorf("Expected logic=AND, got %s", op.Logic)
}
} else {
t.Error("Expected name search operator to exist")
}
},
},
{
name: "Parse search operators - OR",
headers: map[string]string{
"X-SearchOr-Like-Description": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if op, exists := params.SearchOps["description"]; exists {
if op.Logic != "OR" {
t.Errorf("Expected logic=OR, got %s", op.Logic)
}
} else {
t.Error("Expected description search operator to exist")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
params := handler.ParseParameters(req)
if tt.validate != nil {
tt.validate(t, params)
}
})
}
}
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
colName string
operator FilterOperator
expected string
}{
{
name: "Equals operator - numeric",
colName: "age",
operator: FilterOperator{
Operator: "eq",
Value: "25",
Logic: "AND",
},
expected: "age = 25",
},
{
name: "Equals operator - string",
colName: "name",
operator: FilterOperator{
Operator: "eq",
Value: "john",
Logic: "AND",
},
expected: "name = 'john'",
},
{
name: "Not equals operator",
colName: "status",
operator: FilterOperator{
Operator: "neq",
Value: "inactive",
Logic: "AND",
},
expected: "status != 'inactive'",
},
{
name: "Greater than operator",
colName: "age",
operator: FilterOperator{
Operator: "gt",
Value: "18",
Logic: "AND",
},
expected: "age > 18",
},
{
name: "Less than operator",
colName: "price",
operator: FilterOperator{
Operator: "lt",
Value: "100",
Logic: "AND",
},
expected: "price < 100",
},
{
name: "Contains operator",
colName: "description",
operator: FilterOperator{
Operator: "contains",
Value: "test",
Logic: "AND",
},
expected: "description ILIKE '%test%'",
},
{
name: "Starts with operator",
colName: "name",
operator: FilterOperator{
Operator: "startswith",
Value: "john",
Logic: "AND",
},
expected: "name ILIKE 'john%'",
},
{
name: "Ends with operator",
colName: "email",
operator: FilterOperator{
Operator: "endswith",
Value: "@example.com",
Logic: "AND",
},
expected: "email ILIKE '%@example.com'",
},
{
name: "Between operator",
colName: "age",
operator: FilterOperator{
Operator: "between",
Value: "18,65",
Logic: "AND",
},
expected: "age > 18 AND age < 65",
},
{
name: "IN operator",
colName: "status",
operator: FilterOperator{
Operator: "in",
Value: "active,pending,approved",
Logic: "AND",
},
expected: "status IN ('active', 'pending', 'approved')",
},
{
name: "IS NULL operator",
colName: "deleted_at",
operator: FilterOperator{
Operator: "null",
Value: "",
Logic: "AND",
},
expected: "(deleted_at IS NULL OR deleted_at = '')",
},
{
name: "IS NOT NULL operator",
colName: "created_at",
operator: FilterOperator{
Operator: "notnull",
Value: "",
Logic: "AND",
},
expected: "(created_at IS NOT NULL AND created_at != '')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildFilterCondition(tt.colName, tt.operator)
if result != tt.expected {
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
}
})
}
}
// TestApplyFilters tests the filter application to SQL queries
func TestApplyFilters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
params *RequestParameters
expectedSQL string
shouldContain []string
}{
{
name: "Apply field filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
FieldFilters: map[string]string{
"status": "active",
},
},
shouldContain: []string{"WHERE", "status"},
},
{
name: "Apply search filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchFilters: map[string]string{
"name": "john",
},
},
shouldContain: []string{"WHERE", "name", "ILIKE"},
},
{
name: "Apply search operators",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchOps: map[string]FilterOperator{
"age": {
Operator: "gt",
Value: "18",
Logic: "AND",
},
},
},
shouldContain: []string{"WHERE", "age", ">", "18"},
},
{
name: "Apply custom SQL WHERE",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
CustomSQLWhere: "deleted = false",
},
shouldContain: []string{"WHERE", "deleted"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}
// TestApplyDistinct tests DISTINCT application
func TestApplyDistinct(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
distinct bool
shouldHave string
}{
{
name: "Apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: true,
shouldHave: "SELECT DISTINCT",
},
{
name: "Do not apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: false,
shouldHave: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := &RequestParameters{Distinct: tt.distinct}
result := handler.ApplyDistinct(tt.sqlQuery, params)
if tt.shouldHave != "" {
if !strings.Contains(result, tt.shouldHave) {
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
}
} else {
// Should not have DISTINCT when not requested
if strings.Contains(result, "DISTINCT") && !tt.distinct {
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
}
}
})
}
}
// TestParseCommaSeparated tests comma-separated value parsing
func TestParseCommaSeparated(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "id,name,email",
expected: []string{"id", "name", "email"},
},
{
name: "With spaces",
input: "id, name, email",
expected: []string{"id", "name", "email"},
},
{
name: "Empty string",
input: "",
expected: nil,
},
{
name: "Single value",
input: "id",
expected: []string{"id"},
},
{
name: "With extra commas",
input: "id,,name,,email",
expected: []string{"id", "name", "email"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
}
}
})
}
}
// TestSqlQryWhereOr tests OR WHERE clause manipulation
func TestSqlQryWhereOr(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
shouldContain []string
}{
{
name: "Add WHERE with OR to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "status = 'inactive'"},
},
{
name: "Add OR to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}

View File

@@ -0,0 +1,83 @@
package funcspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 2: BeforeQuery - Audit logging before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Note: Row-level security and column masking are challenging in funcspec
// because the SQL query is fully user-defined. Security should be implemented
// at the SQL function level or through database policies (RLS).
}
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
type funcSpecSecurityContext struct {
ctx *HookContext
}
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
return &funcSpecSecurityContext{ctx: ctx}
}
func (f *funcSpecSecurityContext) GetContext() context.Context {
return f.ctx.Context
}
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
if f.ctx.UserContext == nil {
return 0, false
}
return int(f.ctx.UserContext.UserID), true
}
func (f *funcSpecSecurityContext) GetSchema() string {
// funcspec doesn't have a schema concept, extract from SQL query or use default
return "public"
}
func (f *funcSpecSecurityContext) GetEntity() string {
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
return "sql_query"
}
func (f *funcSpecSecurityContext) GetModel() interface{} {
// funcspec doesn't use models in the same way as restheadspec
return nil
}
func (f *funcSpecSecurityContext) GetQuery() interface{} {
// In funcspec, the query is a string, not a query builder object
return f.ctx.SQLQuery
}
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
// In funcspec, we could modify the SQL string, but this should be done cautiously
if sqlQuery, ok := query.(string); ok {
f.ctx.SQLQuery = sqlQuery
}
}
func (f *funcSpecSecurityContext) GetResult() interface{} {
return f.ctx.Result
}
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
f.ctx.Result = result
}

View File

@@ -1,15 +1,19 @@
package logger package logger
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"os" "os"
"runtime/debug" "runtime/debug"
"go.uber.org/zap" "go.uber.org/zap"
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
) )
var Logger *zap.SugaredLogger var Logger *zap.SugaredLogger
var errorTracker errortracking.Provider
func Init(dev bool) { func Init(dev bool) {
@@ -23,6 +27,15 @@ func Init(dev bool) {
} }
func UpdateLoggerPath(path string, dev bool) {
defaultConfig := zap.NewProductionConfig()
if dev {
defaultConfig = zap.NewDevelopmentConfig()
}
defaultConfig.OutputPaths = []string{path}
UpdateLogger(&defaultConfig)
}
func UpdateLogger(config *zap.Config) { func UpdateLogger(config *zap.Config) {
defaultConfig := zap.NewProductionConfig() defaultConfig := zap.NewProductionConfig()
defaultConfig.OutputPaths = []string{"resolvespec.log"} defaultConfig.OutputPaths = []string{"resolvespec.log"}
@@ -40,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
Info("ResolveSpec Logger initialized") Info("ResolveSpec Logger initialized")
} }
// InitErrorTracking initializes the error tracking provider
func InitErrorTracking(provider errortracking.Provider) {
errorTracker = provider
if errorTracker != nil {
Info("Error tracking initialized")
}
}
// GetErrorTracker returns the current error tracking provider
func GetErrorTracker() errortracking.Provider {
return errorTracker
}
// CloseErrorTracking flushes and closes the error tracking provider
func CloseErrorTracking() error {
if errorTracker != nil {
errorTracker.Flush(5)
return errorTracker.Close()
}
return nil
}
func Info(template string, args ...interface{}) { func Info(template string, args ...interface{}) {
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf(template, args...)
@@ -49,19 +84,35 @@ func Info(template string, args ...interface{}) {
} }
func Warn(template string, args ...interface{}) { func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Warnw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Error(template string, args ...interface{}) { func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Errorw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Debug(template string, args ...interface{}) { func Debug(template string, args ...interface{}) {
@@ -75,7 +126,7 @@ func Debug(template string, args ...interface{}) {
// CatchPanic - Handle panic // CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) { func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil { if err := recover(); err != nil {
// callstack := debug.Stack() callstack := debug.Stack()
if Logger != nil { if Logger != nil {
Error("Panic in %s : %v", location, err) Error("Panic in %s : %v", location, err)
@@ -84,14 +135,13 @@ func CatchPanicCallback(location string, cb func(err any)) {
debug.PrintStack() debug.PrintStack()
} }
// push to sentry // Send to error tracker
// hub := sentry.CurrentHub() if errorTracker != nil {
// if hub != nil { errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
// evtID := hub.Recover(err) "location": location,
// if evtID != nil { "process_id": os.Getpid(),
// sentry.Flush(time.Second * 2) })
// } }
// }
if cb != nil { if cb != nil {
cb(err) cb(err)
@@ -104,13 +154,26 @@ func CatchPanic(location string) {
CatchPanicCallback(location, nil) CatchPanicCallback(location, nil)
} }
// RecoverPanic recovers from panics and returns an error // HandlePanic logs a panic and returns it as an error
// Use this in deferred functions to convert panics into errors // This should be called with the result of recover() from a deferred function
func RecoverPanic(methodName string) error { // Example usage:
if r := recover(); r != nil { //
stack := debug.Stack() // defer func() {
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack)) // if r := recover(); r != nil {
return fmt.Errorf("panic in %s: %v", methodName, r) // err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
} }
return nil
return fmt.Errorf("panic in %s: %v", methodName, r)
} }

259
pkg/metrics/README.md Normal file
View File

@@ -0,0 +1,259 @@
# Metrics Package
A pluggable metrics collection system with Prometheus implementation.
## Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
// Initialize Prometheus provider
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Apply middleware to your router
router.Use(provider.Middleware)
// Expose metrics endpoint
http.Handle("/metrics", provider.Handler())
```
## Provider Interface
The package uses a provider interface, allowing you to plug in different metric systems:
```go
type Provider interface {
RecordHTTPRequest(method, path, status string, duration time.Duration)
IncRequestsInFlight()
DecRequestsInFlight()
RecordDBQuery(operation, table string, duration time.Duration, err error)
RecordCacheHit(provider string)
RecordCacheMiss(provider string)
UpdateCacheSize(provider string, size int64)
Handler() http.Handler
}
```
## Recording Metrics
### HTTP Metrics (Automatic)
When using the middleware, HTTP metrics are recorded automatically:
```go
router.Use(provider.Middleware)
```
**Collected:**
- Request duration (histogram)
- Request count by method, path, and status
- Requests in flight (gauge)
### Database Metrics
```go
start := time.Now()
rows, err := db.Query("SELECT * FROM users WHERE id = ?", userID)
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
```
### Cache Metrics
```go
// Record cache hit
metrics.GetProvider().RecordCacheHit("memory")
// Record cache miss
metrics.GetProvider().RecordCacheMiss("memory")
// Update cache size
metrics.GetProvider().UpdateCacheSize("memory", 1024)
```
## Prometheus Metrics
When using `PrometheusProvider`, the following metrics are available:
| Metric Name | Type | Labels | Description |
|-------------|------|--------|-------------|
| `http_request_duration_seconds` | Histogram | method, path, status | HTTP request duration |
| `http_requests_total` | Counter | method, path, status | Total HTTP requests |
| `http_requests_in_flight` | Gauge | - | Current in-flight requests |
| `db_query_duration_seconds` | Histogram | operation, table | Database query duration |
| `db_queries_total` | Counter | operation, table, status | Total database queries |
| `cache_hits_total` | Counter | provider | Total cache hits |
| `cache_misses_total` | Counter | provider | Total cache misses |
| `cache_size_items` | Gauge | provider | Current cache size |
## Prometheus Queries
### HTTP Request Rate
```promql
rate(http_requests_total[5m])
```
### HTTP Request Duration (95th percentile)
```promql
histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))
```
### Database Query Error Rate
```promql
rate(db_queries_total{status="error"}[5m])
```
### Cache Hit Rate
```promql
rate(cache_hits_total[5m]) / (rate(cache_hits_total[5m]) + rate(cache_misses_total[5m]))
```
## No-Op Provider
If metrics are disabled:
```go
// No provider set - uses no-op provider automatically
metrics.GetProvider().RecordHTTPRequest(...) // Does nothing
```
## Custom Provider
Implement your own metrics provider:
```go
type CustomProvider struct{}
func (c *CustomProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
// Send to your metrics system
}
// Implement other Provider interface methods...
func (c *CustomProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return your metrics format
})
}
// Use it
metrics.SetProvider(&CustomProvider{})
```
## Complete Example
```go
package main
import (
"database/sql"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/gorilla/mux"
)
func main() {
// Initialize metrics
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Create router
router := mux.NewRouter()
// Apply metrics middleware
router.Use(provider.Middleware)
// Expose metrics endpoint
router.Handle("/metrics", provider.Handler())
// Your API routes
router.HandleFunc("/api/users", getUsersHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
func getUsersHandler(w http.ResponseWriter, r *http.Request) {
// Record database query
start := time.Now()
users, err := fetchUsers()
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
if err != nil {
http.Error(w, "Internal Server Error", 500)
return
}
// Return users...
}
```
## Docker Compose Example
```yaml
version: '3'
services:
app:
build: .
ports:
- "8080:8080"
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
grafana:
image: grafana/grafana
ports:
- "3000:3000"
depends_on:
- prometheus
```
**prometheus.yml:**
```yaml
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'resolvespec'
static_configs:
- targets: ['app:8080']
```
## Best Practices
1. **Label Cardinality**: Keep labels low-cardinality
- ✅ Good: `method`, `status_code`
- ❌ Bad: `user_id`, `timestamp`
2. **Path Normalization**: Normalize dynamic paths
```go
// Instead of /api/users/123
// Use /api/users/:id
```
3. **Metric Naming**: Follow Prometheus conventions
- Use `_total` suffix for counters
- Use `_seconds` suffix for durations
- Use base units (seconds, not milliseconds)
4. **Performance**: Metrics collection is lock-free and highly performant
- Safe for high-throughput applications
- Minimal overhead (<1% in most cases)

73
pkg/metrics/interfaces.go Normal file
View File

@@ -0,0 +1,73 @@
package metrics
import (
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Provider defines the interface for metric collection
type Provider interface {
// RecordHTTPRequest records metrics for an HTTP request
RecordHTTPRequest(method, path, status string, duration time.Duration)
// IncRequestsInFlight increments the in-flight requests counter
IncRequestsInFlight()
// DecRequestsInFlight decrements the in-flight requests counter
DecRequestsInFlight()
// RecordDBQuery records metrics for a database query
RecordDBQuery(operation, table string, duration time.Duration, err error)
// RecordCacheHit records a cache hit
RecordCacheHit(provider string)
// RecordCacheMiss records a cache miss
RecordCacheMiss(provider string)
// UpdateCacheSize updates the cache size metric
UpdateCacheSize(provider string, size int64)
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
Handler() http.Handler
}
// globalProvider is the global metrics provider
var globalProvider Provider
// SetProvider sets the global metrics provider
func SetProvider(p Provider) {
globalProvider = p
}
// GetProvider returns the current metrics provider
func GetProvider() Provider {
if globalProvider == nil {
// Return no-op provider if none is set
return &NoOpProvider{}
}
return globalProvider
}
// NoOpProvider is a no-op implementation of Provider
type NoOpProvider struct{}
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
func (n *NoOpProvider) IncRequestsInFlight() {}
func (n *NoOpProvider) DecRequestsInFlight() {}
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
}
func (n *NoOpProvider) RecordCacheHit(provider string) {}
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
func (n *NoOpProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, err := w.Write([]byte("Metrics provider not configured"))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
})
}

174
pkg/metrics/prometheus.go Normal file
View File

@@ -0,0 +1,174 @@
package metrics
import (
"net/http"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// PrometheusProvider implements the Provider interface using Prometheus
type PrometheusProvider struct {
requestDuration *prometheus.HistogramVec
requestTotal *prometheus.CounterVec
requestsInFlight prometheus.Gauge
dbQueryDuration *prometheus.HistogramVec
dbQueryTotal *prometheus.CounterVec
cacheHits *prometheus.CounterVec
cacheMisses *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
}
// NewPrometheusProvider creates a new Prometheus metrics provider
func NewPrometheusProvider() *PrometheusProvider {
return &PrometheusProvider{
requestDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
),
requestTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
),
requestsInFlight: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "http_requests_in_flight",
Help: "Current number of HTTP requests being processed",
},
),
dbQueryDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "db_query_duration_seconds",
Help: "Database query duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"operation", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "db_queries_total",
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"provider"},
),
cacheMisses: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"provider"},
),
cacheSize: promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "cache_size_items",
Help: "Number of items in cache",
},
[]string{"provider"},
),
}
}
// ResponseWriter wraps http.ResponseWriter to capture status code
type ResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
}
func (rw *ResponseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// RecordHTTPRequest implements Provider interface
func (p *PrometheusProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
p.requestDuration.WithLabelValues(method, path, status).Observe(duration.Seconds())
p.requestTotal.WithLabelValues(method, path, status).Inc()
}
// IncRequestsInFlight implements Provider interface
func (p *PrometheusProvider) IncRequestsInFlight() {
p.requestsInFlight.Inc()
}
// DecRequestsInFlight implements Provider interface
func (p *PrometheusProvider) DecRequestsInFlight() {
p.requestsInFlight.Dec()
}
// RecordDBQuery implements Provider interface
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
status := "success"
if err != nil {
status = "error"
}
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
}
// RecordCacheHit implements Provider interface
func (p *PrometheusProvider) RecordCacheHit(provider string) {
p.cacheHits.WithLabelValues(provider).Inc()
}
// RecordCacheMiss implements Provider interface
func (p *PrometheusProvider) RecordCacheMiss(provider string) {
p.cacheMisses.WithLabelValues(provider).Inc()
}
// UpdateCacheSize implements Provider interface
func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
p.cacheSize.WithLabelValues(provider).Set(float64(size))
}
// Handler implements Provider interface
func (p *PrometheusProvider) Handler() http.Handler {
return promhttp.Handler()
}
// Middleware returns an HTTP middleware that collects metrics
func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Increment in-flight requests
p.IncRequestsInFlight()
defer p.DecRequestsInFlight()
// Wrap response writer to capture status code
rw := NewResponseWriter(w)
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
duration := time.Since(start)
status := strconv.Itoa(rw.statusCode)
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
})
}

806
pkg/middleware/README.md Normal file
View File

@@ -0,0 +1,806 @@
# Middleware Package
HTTP middleware utilities for security and performance.
## Table of Contents
1. [Rate Limiting](#rate-limiting)
2. [Request Size Limits](#request-size-limits)
3. [Input Sanitization](#input-sanitization)
---
## Rate Limiting
Production-grade rate limiting using token bucket algorithm.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Create rate limiter: 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
// Apply to all routes
router.Use(rateLimiter.Middleware)
```
### Basic Usage
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Rate limit: 10 requests per second, burst of 5
rateLimiter := middleware.NewRateLimiter(10, 5)
router.Use(rateLimiter.Middleware)
router.HandleFunc("/api/data", dataHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
```
### Custom Key Extraction
By default, rate limiting is per IP address. Customize the key:
```go
// Rate limit by User ID from header
keyFunc := func(r *http.Request) string {
userID := r.Header.Get("X-User-ID")
if userID == "" {
return r.RemoteAddr // Fallback to IP
}
return "user:" + userID
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
### Advanced Key Functions
**By API Key:**
```go
keyFunc := func(r *http.Request) string {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
return r.RemoteAddr
}
return "api:" + apiKey
}
```
**By Authenticated User:**
```go
keyFunc := func(r *http.Request) string {
// Extract from JWT or session
user := getUserFromContext(r.Context())
if user != nil {
return "user:" + user.ID
}
return r.RemoteAddr
}
```
**By Path + User:**
```go
keyFunc := func(r *http.Request) string {
user := getUserFromContext(r.Context())
if user != nil {
return fmt.Sprintf("user:%s:path:%s", user.ID, r.URL.Path)
}
return r.URL.Path + ":" + r.RemoteAddr
}
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// Public endpoints: 10 rps
publicLimiter := middleware.NewRateLimiter(10, 5)
// API endpoints: 100 rps
apiLimiter := middleware.NewRateLimiter(100, 20)
// Admin endpoints: 1000 rps
adminLimiter := middleware.NewRateLimiter(1000, 50)
// Apply different limiters to subrouters
publicRouter := router.PathPrefix("/public").Subrouter()
publicRouter.Use(publicLimiter.Middleware)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
adminRouter := router.PathPrefix("/admin").Subrouter()
adminRouter.Use(adminLimiter.Middleware)
}
```
### Rate Limit Response
When rate limited, clients receive:
```http
HTTP/1.1 429 Too Many Requests
Content-Type: text/plain
```
### Configuration Examples
**Tight Rate Limit (Anti-abuse):**
```go
// 1 request per second, burst of 3
rateLimiter := middleware.NewRateLimiter(1, 3)
```
**Moderate Rate Limit (Standard API):**
```go
// 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
```
**Generous Rate Limit (Internal Services):**
```go
// 1000 requests per second, burst of 100
rateLimiter := middleware.NewRateLimiter(1000, 100)
```
**Time-based Limits:**
```go
// 60 requests per minute = 1 request per second
rateLimiter := middleware.NewRateLimiter(1, 10)
// 1000 requests per hour ≈ 0.28 requests per second
rateLimiter := middleware.NewRateLimiter(0.28, 50)
```
### Understanding Burst
The burst parameter allows short bursts above the rate:
```go
// Rate: 10 rps, Burst: 5
// Allows up to 5 requests immediately, then 10/second
rateLimiter := middleware.NewRateLimiter(10, 5)
```
**Bucket fills at rate:** 10 tokens/second
**Bucket capacity:** 5 tokens
**Request consumes:** 1 token
**Example traffic pattern:**
- T=0s: 5 requests → ✅ All allowed (burst)
- T=0.1s: 1 request → ❌ Denied (bucket empty)
- T=0.5s: 1 request → ✅ Allowed (bucket refilled 0.5 tokens)
- T=1s: 1 request → ✅ Allowed (bucket has ~1 token)
### Cleanup Behavior
The rate limiter automatically cleans up inactive limiters every 5 minutes to prevent memory leaks.
### Performance Characteristics
- **Memory**: ~100 bytes per active limiter
- **Throughput**: >1M requests/second
- **Latency**: <1μs per request
- **Concurrency**: Lock-free for rate checks
### Production Deployment
**With Reverse Proxy:**
```go
// Use X-Forwarded-For or X-Real-IP
keyFunc := func(r *http.Request) string {
// Check proxy headers first
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
return r.RemoteAddr
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
**Environment-based Configuration:**
```go
import "os"
func getRateLimiter() *middleware.RateLimiter {
rps := getEnvFloat("RATE_LIMIT_RPS", 100)
burst := getEnvInt("RATE_LIMIT_BURST", 20)
return middleware.NewRateLimiter(rps, burst)
}
```
### Testing Rate Limits
```bash
# Send 10 requests rapidly
for i in {1..10}; do
curl -w "Status: %{http_code}\n" http://localhost:8080/api/data
done
```
**Expected output:**
```
Status: 200 # Request 1-5 (within burst)
Status: 200
Status: 200
Status: 200
Status: 200
Status: 429 # Request 6-10 (rate limited)
Status: 429
Status: 429
Status: 429
Status: 429
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"strconv"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
// Configuration from environment
rps, _ := strconv.ParseFloat(os.Getenv("RATE_LIMIT_RPS"), 64)
if rps == 0 {
rps = 100 // Default
}
burst, _ := strconv.Atoi(os.Getenv("RATE_LIMIT_BURST"))
if burst == 0 {
burst = 20 // Default
}
// Create rate limiter
rateLimiter := middleware.NewRateLimiter(rps, burst)
// Custom key extraction
keyFunc := func(r *http.Request) string {
// Try API key first
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
return "api:" + apiKey
}
// Try authenticated user
if userID := r.Header.Get("X-User-ID"); userID != "" {
return "user:" + userID
}
// Fall back to IP
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}
// Create router
router := mux.NewRouter()
// Apply rate limiting
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
// Routes
router.HandleFunc("/api/data", dataHandler)
router.HandleFunc("/health", healthHandler)
log.Printf("Starting server with rate limit: %.1f rps, burst: %d", rps, burst)
log.Fatal(http.ListenAndServe(":8080", router))
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"message": "Data endpoint",
})
}
func healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
```
## Best Practices
1. **Set Appropriate Limits**: Consider your backend capacity
- Database: Can it handle X queries/second?
- External APIs: What are their rate limits?
- Server resources: CPU, memory, connections
2. **Use Burst Wisely**: Allow legitimate traffic spikes
- Too low: Reject valid bursts
- Too high: Allow abuse
3. **Monitor Rate Limits**: Track how often limits are hit
```go
// Log rate limit events
if rateLimited {
log.Printf("Rate limited: %s", clientKey)
}
```
4. **Provide Feedback**: Include rate limit headers (future enhancement)
```http
X-RateLimit-Limit: 100
X-RateLimit-Remaining: 95
X-RateLimit-Reset: 1640000000
```
5. **Tiered Limits**: Different limits for different user tiers
```go
func getRateLimiter(userTier string) *middleware.RateLimiter {
switch userTier {
case "premium":
return middleware.NewRateLimiter(1000, 100)
case "standard":
return middleware.NewRateLimiter(100, 20)
default:
return middleware.NewRateLimiter(10, 5)
}
}
```
---
## Request Size Limits
Protect against oversized request bodies with configurable size limits.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default: 10MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(0)
router.Use(sizeLimiter.Middleware)
```
### Custom Size Limit
```go
// 5MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(5 * 1024 * 1024)
router.Use(sizeLimiter.Middleware)
// Or use constants
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size5MB)
```
### Available Size Constants
```go
middleware.Size1MB // 1 MB
middleware.Size5MB // 5 MB
middleware.Size10MB // 10 MB (default)
middleware.Size50MB // 50 MB
middleware.Size100MB // 100 MB
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// File upload endpoint: 50MB
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
uploadRouter := router.PathPrefix("/upload").Subrouter()
uploadRouter.Use(uploadLimiter.Middleware)
// API endpoints: 1MB
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
}
```
### Dynamic Size Limits
```go
// Custom size based on request
sizeFunc := func(r *http.Request) int64 {
// Premium users get 50MB
if isPremiumUser(r) {
return middleware.Size50MB
}
// Free users get 5MB
return middleware.Size5MB
}
router.Use(sizeLimiter.MiddlewareWithCustomSize(sizeFunc))
```
**By Content-Type:**
```go
sizeFunc := func(r *http.Request) int64 {
contentType := r.Header.Get("Content-Type")
switch {
case strings.Contains(contentType, "multipart/form-data"):
return middleware.Size50MB // File uploads
case strings.Contains(contentType, "application/json"):
return middleware.Size1MB // JSON APIs
default:
return middleware.Size10MB // Default
}
}
```
### Error Response
When size limit exceeded:
```http
HTTP/1.1 413 Request Entity Too Large
X-Max-Request-Size: 10485760
http: request body too large
```
### Complete Example
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// API routes: 1MB limit
api := router.PathPrefix("/api").Subrouter()
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
api.Use(apiLimiter.Middleware)
api.HandleFunc("/users", createUserHandler).Methods("POST")
// Upload routes: 50MB limit
upload := router.PathPrefix("/upload").Subrouter()
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
upload.Use(uploadLimiter.Middleware)
upload.HandleFunc("/file", uploadFileHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
```
---
## Input Sanitization
Protect against XSS, injection attacks, and malicious input.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default sanitizer (safe defaults)
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
```
### Sanitizer Types
**Default Sanitizer (Recommended):**
```go
sanitizer := middleware.DefaultSanitizer()
// ✓ Escapes HTML entities
// ✓ Removes null bytes
// ✓ Removes control characters
// ✓ Blocks XSS patterns (script tags, event handlers)
// ✗ Does not strip HTML (allows legitimate content)
```
**Strict Sanitizer:**
```go
sanitizer := middleware.StrictSanitizer()
// ✓ All default features
// ✓ Strips ALL HTML tags
// ✓ Max string length: 10,000 chars
```
### Custom Configuration
```go
sanitizer := &middleware.Sanitizer{
StripHTML: true, // Remove HTML tags
EscapeHTML: false, // Don't escape (already stripped)
RemoveNullBytes: true, // Remove \x00
RemoveControlChars: true, // Remove dangerous control chars
MaxStringLength: 5000, // Limit to 5000 chars
// Block patterns (regex)
BlockPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
},
// Custom sanitization function
CustomSanitizer: func(s string) string {
// Your custom logic
return strings.ToLower(s)
},
}
router.Use(sanitizer.Middleware)
```
### What Gets Sanitized
**Automatic (via middleware):**
- Query parameters
- Headers (User-Agent, Referer, X-Forwarded-For, X-Real-IP)
**Manual (in your handler):**
- Request body (JSON, form data)
- Database queries
- File names
### Manual Sanitization
**String Values:**
```go
sanitizer := middleware.DefaultSanitizer()
// Sanitize user input
username := sanitizer.Sanitize(r.FormValue("username"))
email := sanitizer.Sanitize(r.FormValue("email"))
```
**Map/JSON Data:**
```go
var data map[string]interface{}
json.Unmarshal(body, &data)
// Sanitize all string values recursively
sanitizedData := sanitizer.SanitizeMap(data)
```
**Nested Structures:**
```go
type User struct {
Name string
Email string
Bio string
Profile map[string]interface{}
}
// After unmarshaling
user.Name = sanitizer.Sanitize(user.Name)
user.Email = sanitizer.Sanitize(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
user.Profile = sanitizer.SanitizeMap(user.Profile)
```
### Specialized Sanitizers
**Filenames:**
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
filename := middleware.SanitizeFilename(uploadedFilename)
// Removes: .., /, \, null bytes
// Limits: 255 characters
```
**Emails:**
```go
email := middleware.SanitizeEmail(" USER@EXAMPLE.COM ")
// Result: "user@example.com"
// Trims, lowercases, removes null bytes
```
**URLs:**
```go
url := middleware.SanitizeURL(userInput)
// Blocks: javascript:, data: protocols
// Removes: null bytes
```
### Blocked Patterns (Default)
The default sanitizer blocks:
1. **Script tags**: `<script>...</script>`
2. **JavaScript protocol**: `javascript:alert(1)`
3. **Event handlers**: `onclick="..."`, `onerror="..."`
4. **Iframes**: `<iframe src="...">`
5. **Objects**: `<object data="...">`
6. **Embeds**: `<embed src="...">`
### Security Best Practices
**1. Layer Defense:**
```go
// Layer 1: Middleware (query params, headers)
router.Use(sanitizer.Middleware)
// Layer 2: Input validation (in handler)
func createUserHandler(w http.ResponseWriter, r *http.Request) {
var user User
json.NewDecoder(r.Body).Decode(&user)
// Sanitize
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
// Validate
if !isValidEmail(user.Email) {
http.Error(w, "Invalid email", 400)
return
}
// Use parameterized queries (prevents SQL injection)
db.Exec("INSERT INTO users (name, email) VALUES (?, ?)",
user.Name, user.Email)
}
```
**2. Context-Aware Sanitization:**
```go
// HTML content (user posts, comments)
sanitizer := middleware.StrictSanitizer()
post.Content = sanitizer.Sanitize(post.Content)
// Structured data (JSON API)
sanitizer := middleware.DefaultSanitizer()
data = sanitizer.SanitizeMap(jsonData)
// Search queries (preserve special chars)
query = middleware.SanitizeFilename(searchTerm) // Light sanitization
```
**3. Output Encoding:**
```go
// When rendering HTML
import "html/template"
tmpl := template.Must(template.New("page").Parse(`
<h1>{{.Title}}</h1>
<p>{{.Content}}</p>
`))
// template.HTML automatically escapes
tmpl.Execute(w, data)
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Apply sanitization middleware
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
router.HandleFunc("/api/users", createUserHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
func createUserHandler(w http.ResponseWriter, r *http.Request) {
sanitizer := middleware.DefaultSanitizer()
var user struct {
Name string `json:"name"`
Email string `json:"email"`
Bio string `json:"bio"`
}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
http.Error(w, "Invalid JSON", 400)
return
}
// Sanitize inputs
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
// Validate
if len(user.Name) == 0 || len(user.Email) == 0 {
http.Error(w, "Name and email required", 400)
return
}
// Save to database (use parameterized queries!)
// db.Exec("INSERT INTO users (name, email, bio) VALUES (?, ?, ?)",
// user.Name, user.Email, user.Bio)
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{
"status": "created",
})
}
```
### Testing Sanitization
```bash
# Test XSS prevention
curl -X POST http://localhost:8080/api/users \
-H "Content-Type: application/json" \
-d '{
"name": "<script>alert(1)</script>John",
"email": "test@example.com",
"bio": "My bio with <iframe src=\"evil.com\"></iframe>"
}'
# Script tags and iframes should be removed
```
### Performance
- **Overhead**: <1ms per request for typical payloads
- **Regex compilation**: Done once at initialization
- **Safe for production**: Minimal performance impact
- **Safe for production**: Minimal performance impact

212
pkg/middleware/blacklist.go Normal file
View File

@@ -0,0 +1,212 @@
package middleware
import (
"encoding/json"
"net"
"net/http"
"strings"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// IPBlacklist provides IP blocking functionality
type IPBlacklist struct {
mu sync.RWMutex
ips map[string]bool // Individual IPs
cidrs []*net.IPNet // CIDR ranges
reason map[string]string
useProxy bool // Whether to check X-Forwarded-For headers
}
// BlacklistConfig configures the IP blacklist
type BlacklistConfig struct {
// UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers
UseProxy bool
}
// NewIPBlacklist creates a new IP blacklist
func NewIPBlacklist(config BlacklistConfig) *IPBlacklist {
return &IPBlacklist{
ips: make(map[string]bool),
cidrs: make([]*net.IPNet, 0),
reason: make(map[string]string),
useProxy: config.UseProxy,
}
}
// BlockIP blocks a single IP address
func (bl *IPBlacklist) BlockIP(ip string, reason string) error {
// Validate IP
if net.ParseIP(ip) == nil {
return &net.ParseError{Type: "IP address", Text: ip}
}
bl.mu.Lock()
defer bl.mu.Unlock()
bl.ips[ip] = true
if reason != "" {
bl.reason[ip] = reason
}
return nil
}
// BlockCIDR blocks an IP range using CIDR notation
func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error {
_, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
return err
}
bl.mu.Lock()
defer bl.mu.Unlock()
bl.cidrs = append(bl.cidrs, ipNet)
if reason != "" {
bl.reason[cidr] = reason
}
return nil
}
// UnblockIP removes an IP from the blacklist
func (bl *IPBlacklist) UnblockIP(ip string) {
bl.mu.Lock()
defer bl.mu.Unlock()
delete(bl.ips, ip)
delete(bl.reason, ip)
}
// UnblockCIDR removes a CIDR range from the blacklist
func (bl *IPBlacklist) UnblockCIDR(cidr string) {
bl.mu.Lock()
defer bl.mu.Unlock()
// Find and remove the CIDR
for i, ipNet := range bl.cidrs {
if ipNet.String() == cidr {
bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...)
break
}
}
delete(bl.reason, cidr)
}
// IsBlocked checks if an IP is blacklisted
func (bl *IPBlacklist) IsBlocked(ip string) (blacklist bool, reason string) {
bl.mu.RLock()
defer bl.mu.RUnlock()
// Check individual IPs
if bl.ips[ip] {
return true, bl.reason[ip]
}
// Check CIDR ranges
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false, ""
}
for i, ipNet := range bl.cidrs {
if ipNet.Contains(parsedIP) {
cidr := ipNet.String()
// Try to find reason by CIDR or by index
if reason, ok := bl.reason[cidr]; ok {
return true, reason
}
// Check if reason was stored by original CIDR string
for key, reason := range bl.reason {
if strings.Contains(key, "/") && key == cidr {
return true, reason
}
}
// Return true even if no reason found
if i < len(bl.cidrs) {
return true, ""
}
}
}
return false, ""
}
// GetBlacklist returns all blacklisted IPs and CIDRs
func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) {
bl.mu.RLock()
defer bl.mu.RUnlock()
ips = make([]string, 0, len(bl.ips))
for ip := range bl.ips {
ips = append(ips, ip)
}
cidrs = make([]string, 0, len(bl.cidrs))
for _, ipNet := range bl.cidrs {
cidrs = append(cidrs, ipNet.String())
}
return ips, cidrs
}
// Middleware returns an HTTP middleware that blocks blacklisted IPs
func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var clientIP string
if bl.useProxy {
clientIP = getClientIP(r)
// Clean up IPv6 brackets if present
clientIP = strings.Trim(clientIP, "[]")
} else {
// Extract IP from RemoteAddr
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
clientIP = r.RemoteAddr[:idx]
} else {
clientIP = r.RemoteAddr
}
clientIP = strings.Trim(clientIP, "[]")
}
blocked, reason := bl.IsBlocked(clientIP)
if blocked {
response := map[string]interface{}{
"error": "forbidden",
"message": "Access denied",
}
if reason != "" {
response["reason"] = reason
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusForbidden)
err := json.NewEncoder(w).Encode(response)
if err != nil {
logger.Debug("Failed to write blacklist response: %v", err)
}
return
}
next.ServeHTTP(w, r)
})
}
// StatsHandler returns an HTTP handler that shows blacklist statistics
func (bl *IPBlacklist) StatsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ips, cidrs := bl.GetBlacklist()
stats := map[string]interface{}{
"blocked_ips": ips,
"blocked_cidrs": cidrs,
"total_ips": len(ips),
"total_cidrs": len(cidrs),
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(stats)
if err != nil {
logger.Debug("Failed to encode stats: %v", err)
}
})
}

View File

@@ -0,0 +1,254 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestIPBlacklist_BlockIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block an IP
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
if err != nil {
t.Fatalf("BlockIP() error = %v", err)
}
// Check if IP is blocked
blocked, reason := bl.IsBlocked("192.168.1.100")
if !blocked {
t.Error("IP should be blocked")
}
if reason != "Suspicious activity" {
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
}
// Check non-blocked IP
blocked, _ = bl.IsBlocked("192.168.1.1")
if blocked {
t.Error("IP should not be blocked")
}
}
func TestIPBlacklist_BlockCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block a CIDR range
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
if err != nil {
t.Fatalf("BlockCIDR() error = %v", err)
}
// Check IPs in range
testIPs := []string{
"10.0.0.1",
"10.0.0.100",
"10.0.0.254",
}
for _, ip := range testIPs {
blocked, _ := bl.IsBlocked(ip)
if !blocked {
t.Errorf("IP %s should be blocked by CIDR", ip)
}
}
// Check IP outside range
blocked, _ := bl.IsBlocked("10.0.1.1")
if blocked {
t.Error("IP outside CIDR range should not be blocked")
}
}
func TestIPBlacklist_UnblockIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block and then unblock
bl.BlockIP("192.168.1.100", "Test")
blocked, _ := bl.IsBlocked("192.168.1.100")
if !blocked {
t.Error("IP should be blocked")
}
bl.UnblockIP("192.168.1.100")
blocked, _ = bl.IsBlocked("192.168.1.100")
if blocked {
t.Error("IP should be unblocked")
}
}
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
// Block and then unblock CIDR
bl.BlockCIDR("10.0.0.0/24", "Test")
blocked, _ := bl.IsBlocked("10.0.0.1")
if !blocked {
t.Error("IP should be blocked by CIDR")
}
bl.UnblockCIDR("10.0.0.0/24")
blocked, _ = bl.IsBlocked("10.0.0.1")
if blocked {
t.Error("IP should be unblocked after CIDR removal")
}
}
func TestIPBlacklist_Middleware(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "Banned")
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
// Blocked IP should get 403
t.Run("BlockedIP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.100:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
var response map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if response["error"] != "forbidden" {
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
}
})
// Allowed IP should succeed
t.Run("AllowedIP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
})
}
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
bl.BlockIP("203.0.113.1", "Blocked via proxy")
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Test X-Forwarded-For
t.Run("X-Forwarded-For", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "10.0.0.1:12345"
req.Header.Set("X-Forwarded-For", "203.0.113.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
})
// Test X-Real-IP
t.Run("X-Real-IP", func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "10.0.0.1:12345"
req.Header.Set("X-Real-IP", "203.0.113.1")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
}
})
}
func TestIPBlacklist_StatsHandler(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "Test1")
bl.BlockIP("192.168.1.101", "Test2")
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
handler := bl.StatsHandler()
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if int(stats["total_ips"].(float64)) != 2 {
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
}
if int(stats["total_cidrs"].(float64)) != 1 {
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
}
}
func TestIPBlacklist_GetBlacklist(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
bl.BlockIP("192.168.1.100", "")
bl.BlockIP("192.168.1.101", "")
bl.BlockCIDR("10.0.0.0/24", "")
ips, cidrs := bl.GetBlacklist()
if len(ips) != 2 {
t.Errorf("len(ips) = %d, want 2", len(ips))
}
if len(cidrs) != 1 {
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
}
// Verify CIDR format
if cidrs[0] != "10.0.0.0/24" {
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
}
}
func TestIPBlacklist_InvalidIP(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
err := bl.BlockIP("invalid-ip", "Test")
if err == nil {
t.Error("BlockIP() should return error for invalid IP")
}
}
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
err := bl.BlockCIDR("invalid-cidr", "Test")
if err == nil {
t.Error("BlockCIDR() should return error for invalid CIDR")
}
}

233
pkg/middleware/ratelimit.go Normal file
View File

@@ -0,0 +1,233 @@
// Package middleware provides HTTP middleware functionalities such as rate limiting and IP blacklisting.
package middleware
//nolint:all
import (
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"golang.org/x/time/rate"
)
// RateLimiter provides rate limiting functionality
type RateLimiter struct {
mu sync.RWMutex
limiters map[string]*rate.Limiter
rate rate.Limit
burst int
cleanup time.Duration
}
// NewRateLimiter creates a new rate limiter
// rps is requests per second, burst is the maximum burst size
func NewRateLimiter(rps float64, burst int) *RateLimiter {
rl := &RateLimiter{
limiters: make(map[string]*rate.Limiter),
rate: rate.Limit(rps),
burst: burst,
cleanup: 5 * time.Minute, // Clean up stale limiters every 5 minutes
}
// Start cleanup goroutine
go rl.cleanupRoutine()
return rl
}
// getLimiter returns the rate limiter for a given key (e.g., IP address)
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
limiter, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
return limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
// Double-check after acquiring write lock
if limiter, exists := rl.limiters[key]; exists {
return limiter
}
limiter = rate.NewLimiter(rl.rate, rl.burst)
rl.limiters[key] = limiter
return limiter
}
// cleanupRoutine periodically removes inactive limiters
func (rl *RateLimiter) cleanupRoutine() {
ticker := time.NewTicker(rl.cleanup)
defer ticker.Stop()
for range ticker.C {
rl.mu.Lock()
// Simple cleanup: remove all limiters
// In production, you might want to track last access time
rl.limiters = make(map[string]*rate.Limiter)
rl.mu.Unlock()
}
}
// Middleware returns an HTTP middleware that applies rate limiting
// Automatically handles X-Forwarded-For headers when behind a proxy
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Extract client IP, handling proxy headers
key := getClientIP(r)
limiter := rl.getLimiter(key)
if !limiter.Allow() {
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
// MiddlewareWithKeyFunc returns an HTTP middleware with a custom key extraction function
func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := keyFunc(r)
if key == "" {
key = r.RemoteAddr
}
limiter := rl.getLimiter(key)
if !limiter.Allow() {
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}
// RateLimitInfo contains information about a specific IP's rate limit status
type RateLimitInfo struct {
IP string `json:"ip"`
TokensRemaining float64 `json:"tokens_remaining"`
Limit float64 `json:"limit"`
Burst int `json:"burst"`
}
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
func (rl *RateLimiter) GetTrackedIPs() []string {
rl.mu.RLock()
defer rl.mu.RUnlock()
ips := make([]string, 0, len(rl.limiters))
for ip := range rl.limiters {
ips = append(ips, ip)
}
return ips
}
// GetRateLimitInfo returns rate limit information for a specific IP
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
rl.mu.RLock()
limiter, exists := rl.limiters[ip]
rl.mu.RUnlock()
if !exists {
// Return default info for untracked IP
return &RateLimitInfo{
IP: ip,
TokensRemaining: float64(rl.burst),
Limit: float64(rl.rate),
Burst: rl.burst,
}
}
return &RateLimitInfo{
IP: ip,
TokensRemaining: limiter.Tokens(),
Limit: float64(rl.rate),
Burst: rl.burst,
}
}
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
ips := rl.GetTrackedIPs()
info := make([]*RateLimitInfo, 0, len(ips))
for _, ip := range ips {
info = append(info, rl.GetRateLimitInfo(ip))
}
return info
}
// StatsHandler returns an HTTP handler that exposes rate limit statistics
// Example: GET /rate-limit-stats
func (rl *RateLimiter) StatsHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Support querying specific IP via ?ip=x.x.x.x
if ip := r.URL.Query().Get("ip"); ip != "" {
info := rl.GetRateLimitInfo(ip)
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(info)
if err != nil {
logger.Debug("Failed to encode json: %v", err)
}
return
}
// Return all tracked IPs
allInfo := rl.GetAllRateLimitInfo()
stats := map[string]interface{}{
"total_tracked_ips": len(allInfo),
"rate_limit_config": map[string]interface{}{
"requests_per_second": float64(rl.rate),
"burst": rl.burst,
},
"tracked_ips": allInfo,
}
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(stats)
if err != nil {
logger.Debug("Failed to encode json: %v", err)
}
})
}
// getClientIP extracts the real client IP from the request
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
func getClientIP(r *http.Request) string {
// Check X-Forwarded-For header (most common in production)
// Format: X-Forwarded-For: client, proxy1, proxy2
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first IP (the original client)
if idx := strings.Index(xff, ","); idx != -1 {
return strings.TrimSpace(xff[:idx])
}
return strings.TrimSpace(xff)
}
// Check X-Real-IP header (used by some proxies like nginx)
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return strings.TrimSpace(xri)
}
// Fall back to RemoteAddr
// Remove port if present (format: "ip:port")
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}

View File

@@ -0,0 +1,388 @@
package middleware
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestRateLimiter(t *testing.T) {
// Create rate limiter: 2 requests per second, burst of 2
rl := NewRateLimiter(2, 2)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}))
// First request should succeed
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Second request should succeed (within burst)
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Third request should be rate limited
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests)
}
// Wait for rate limiter to refill
time.Sleep(600 * time.Millisecond)
// Request should succeed again
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK)
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
rl := NewRateLimiter(1, 1)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First IP
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
// Second IP
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.2:12345"
// Both should succeed (different IPs)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w1.Code != http.StatusOK {
t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK)
}
if w2.Code != http.StatusOK {
t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK)
}
}
func TestGetClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
expectedIP string
}{
{
name: "RemoteAddr only",
remoteAddr: "192.168.1.1:12345",
expectedIP: "192.168.1.1",
},
{
name: "X-Forwarded-For single IP",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1",
expectedIP: "203.0.113.1",
},
{
name: "X-Forwarded-For multiple IPs",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3",
expectedIP: "203.0.113.1",
},
{
name: "X-Real-IP",
remoteAddr: "10.0.0.1:12345",
xRealIP: "203.0.113.1",
expectedIP: "203.0.113.1",
},
{
name: "X-Forwarded-For takes precedence over X-Real-IP",
remoteAddr: "10.0.0.1:12345",
xForwardedFor: "203.0.113.1",
xRealIP: "203.0.113.2",
expectedIP: "203.0.113.1",
},
{
name: "IPv6 address",
remoteAddr: "[2001:db8::1]:12345",
expectedIP: "[2001:db8::1]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = tt.remoteAddr
if tt.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
}
if tt.xRealIP != "" {
req.Header.Set("X-Real-IP", tt.xRealIP)
}
ip := getClientIP(req)
if ip != tt.expectedIP {
t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP)
}
})
}
}
func TestRateLimiterWithCustomKeyFunc(t *testing.T) {
rl := NewRateLimiter(1, 1)
// Use user ID as key
keyFunc := func(r *http.Request) string {
userID := r.Header.Get("X-User-ID")
if userID == "" {
return r.RemoteAddr
}
return "user:" + userID
}
handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// User 1
req1 := httptest.NewRequest("GET", "/test", nil)
req1.Header.Set("X-User-ID", "user1")
// User 2
req2 := httptest.NewRequest("GET", "/test", nil)
req2.Header.Set("X-User-ID", "user2")
// Both users should succeed (different keys)
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w1.Code != http.StatusOK {
t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK)
}
if w2.Code != http.StatusOK {
t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK)
}
// User 1 second request should be rate limited
w1 = httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
if w1.Code != http.StatusTooManyRequests {
t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests)
}
}
func TestRateLimiter_GetTrackedIPs(t *testing.T) {
rl := NewRateLimiter(10, 10)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
for _, ip := range ips {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = ip + ":12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
// Check tracked IPs
trackedIPs := rl.GetTrackedIPs()
if len(trackedIPs) != len(ips) {
t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips))
}
// Verify all IPs are tracked
ipMap := make(map[string]bool)
for _, ip := range trackedIPs {
ipMap[ip] = true
}
for _, ip := range ips {
if !ipMap[ip] {
t.Errorf("IP %s should be tracked", ip)
}
}
}
func TestRateLimiter_GetRateLimitInfo(t *testing.T) {
rl := NewRateLimiter(10, 5)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make a request
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
// Get rate limit info
info := rl.GetRateLimitInfo("192.168.1.1")
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
if info.Limit != 10.0 {
t.Errorf("Limit = %f, want 10.0", info.Limit)
}
if info.Burst != 5 {
t.Errorf("Burst = %d, want 5", info.Burst)
}
// Tokens should be less than burst after one request
if info.TokensRemaining >= float64(info.Burst) {
t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst)
}
}
func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) {
rl := NewRateLimiter(10, 5)
// Get info for untracked IP (should return default)
info := rl.GetRateLimitInfo("192.168.1.1")
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
if info.TokensRemaining != float64(rl.burst) {
t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst)
}
}
func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) {
rl := NewRateLimiter(10, 10)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
ips := []string{"192.168.1.1", "192.168.1.2"}
for _, ip := range ips {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = ip + ":12345"
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
}
// Get all rate limit info
allInfo := rl.GetAllRateLimitInfo()
if len(allInfo) != len(ips) {
t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips))
}
// Verify each IP has info
for _, info := range allInfo {
found := false
for _, ip := range ips {
if info.IP == ip {
found = true
break
}
}
if !found {
t.Errorf("Unexpected IP in info: %s", info.IP)
}
}
}
func TestRateLimiter_StatsHandler(t *testing.T) {
rl := NewRateLimiter(10, 5)
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make requests from different IPs
req1 := httptest.NewRequest("GET", "/test", nil)
req1.RemoteAddr = "192.168.1.1:12345"
w1 := httptest.NewRecorder()
handler.ServeHTTP(w1, req1)
req2 := httptest.NewRequest("GET", "/test", nil)
req2.RemoteAddr = "192.168.1.2:12345"
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
// Test stats handler (all IPs)
t.Run("AllIPs", func(t *testing.T) {
statsHandler := rl.StatsHandler()
req := httptest.NewRequest("GET", "/rate-limit-stats", nil)
w := httptest.NewRecorder()
statsHandler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var stats map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if int(stats["total_tracked_ips"].(float64)) != 2 {
t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"])
}
config := stats["rate_limit_config"].(map[string]interface{})
if config["requests_per_second"].(float64) != 10.0 {
t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"])
}
})
// Test stats handler (specific IP)
t.Run("SpecificIP", func(t *testing.T) {
statsHandler := rl.StatsHandler()
req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil)
w := httptest.NewRecorder()
statsHandler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
}
var info RateLimitInfo
if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if info.IP != "192.168.1.1" {
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
}
})
}

251
pkg/middleware/sanitize.go Normal file
View File

@@ -0,0 +1,251 @@
package middleware
import (
"html"
"net/http"
"regexp"
"strings"
)
// Sanitizer provides input sanitization beyond SQL injection protection
type Sanitizer struct {
// StripHTML removes HTML tags from input
StripHTML bool
// EscapeHTML escapes HTML entities
EscapeHTML bool
// RemoveNullBytes removes null bytes from input
RemoveNullBytes bool
// RemoveControlChars removes control characters (except newline, carriage return, tab)
RemoveControlChars bool
// MaxStringLength limits individual string field length (0 = no limit)
MaxStringLength int
// BlockPatterns are regex patterns to block (e.g., script tags, SQL keywords)
BlockPatterns []*regexp.Regexp
// Custom sanitization function
CustomSanitizer func(string) string
}
// DefaultSanitizer returns a sanitizer with secure defaults
func DefaultSanitizer() *Sanitizer {
return &Sanitizer{
StripHTML: false, // Don't strip by default (breaks legitimate HTML content)
EscapeHTML: true, // Escape HTML entities to prevent XSS
RemoveNullBytes: true, // Remove null bytes (security best practice)
RemoveControlChars: true, // Remove dangerous control characters
MaxStringLength: 0, // No limit by default
// Block common XSS and injection patterns
BlockPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), // Script tags
regexp.MustCompile(`(?i)javascript:`), // JavaScript protocol
regexp.MustCompile(`(?i)on\w+\s*=`), // Event handlers (onclick, onerror, etc.)
regexp.MustCompile(`(?i)<iframe[^>]*>`), // Iframes
regexp.MustCompile(`(?i)<object[^>]*>`), // Objects
regexp.MustCompile(`(?i)<embed[^>]*>`), // Embeds
},
}
}
// StrictSanitizer returns a sanitizer with very strict rules
func StrictSanitizer() *Sanitizer {
s := DefaultSanitizer()
s.StripHTML = true
s.MaxStringLength = 10000
return s
}
// Sanitize sanitizes a string value
func (s *Sanitizer) Sanitize(value string) string {
if value == "" {
return value
}
// Remove null bytes
if s.RemoveNullBytes {
value = strings.ReplaceAll(value, "\x00", "")
}
// Remove control characters
if s.RemoveControlChars {
value = removeControlCharacters(value)
}
// Check block patterns
for _, pattern := range s.BlockPatterns {
if pattern.MatchString(value) {
// Replace matched pattern with empty string
value = pattern.ReplaceAllString(value, "")
}
}
// Strip HTML tags
if s.StripHTML {
value = stripHTMLTags(value)
}
// Escape HTML entities
if s.EscapeHTML && !s.StripHTML {
value = html.EscapeString(value)
}
// Apply max length
if s.MaxStringLength > 0 && len(value) > s.MaxStringLength {
value = value[:s.MaxStringLength]
}
// Apply custom sanitizer
if s.CustomSanitizer != nil {
value = s.CustomSanitizer(value)
}
return value
}
// SanitizeMap sanitizes all string values in a map
func (s *Sanitizer) SanitizeMap(data map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
for key, value := range data {
result[key] = s.sanitizeValue(value)
}
return result
}
// sanitizeValue recursively sanitizes values
func (s *Sanitizer) sanitizeValue(value interface{}) interface{} {
switch v := value.(type) {
case string:
return s.Sanitize(v)
case map[string]interface{}:
return s.SanitizeMap(v)
case []interface{}:
result := make([]interface{}, len(v))
for i, item := range v {
result[i] = s.sanitizeValue(item)
}
return result
default:
return value
}
}
// Middleware returns an HTTP middleware that sanitizes request headers and query params
// Note: Body sanitization should be done at the application level after parsing
func (s *Sanitizer) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Sanitize query parameters
if r.URL.RawQuery != "" {
q := r.URL.Query()
sanitized := false
for key, values := range q {
for i, value := range values {
sanitizedValue := s.Sanitize(value)
if sanitizedValue != value {
values[i] = sanitizedValue
sanitized = true
}
}
if sanitized {
q[key] = values
}
}
if sanitized {
r.URL.RawQuery = q.Encode()
}
}
// Sanitize specific headers (User-Agent, Referer, etc.)
dangerousHeaders := []string{
"User-Agent",
"Referer",
"X-Forwarded-For",
"X-Real-IP",
}
for _, header := range dangerousHeaders {
if value := r.Header.Get(header); value != "" {
sanitized := s.Sanitize(value)
if sanitized != value {
r.Header.Set(header, sanitized)
}
}
}
next.ServeHTTP(w, r)
})
}
// Helper functions
// removeControlCharacters removes control characters except \n, \r, \t
func removeControlCharacters(s string) string {
var result strings.Builder
for _, r := range s {
// Keep newline, carriage return, tab, and non-control characters
if r == '\n' || r == '\r' || r == '\t' || r >= 32 {
result.WriteRune(r)
}
}
return result.String()
}
// stripHTMLTags removes HTML tags from a string
func stripHTMLTags(s string) string {
// Simple regex to remove HTML tags
re := regexp.MustCompile(`<[^>]*>`)
return re.ReplaceAllString(s, "")
}
// Common sanitization patterns
// SanitizeFilename sanitizes a filename
func SanitizeFilename(filename string) string {
// Remove path traversal attempts
filename = strings.ReplaceAll(filename, "..", "")
filename = strings.ReplaceAll(filename, "/", "")
filename = strings.ReplaceAll(filename, "\\", "")
// Remove null bytes
filename = strings.ReplaceAll(filename, "\x00", "")
// Limit length
if len(filename) > 255 {
filename = filename[:255]
}
return filename
}
// SanitizeEmail performs basic email sanitization
func SanitizeEmail(email string) string {
email = strings.TrimSpace(strings.ToLower(email))
// Remove dangerous characters
email = strings.ReplaceAll(email, "\x00", "")
email = removeControlCharacters(email)
return email
}
// SanitizeURL performs basic URL sanitization
func SanitizeURL(url string) string {
url = strings.TrimSpace(url)
// Remove null bytes
url = strings.ReplaceAll(url, "\x00", "")
// Block javascript: and data: protocols
if strings.HasPrefix(strings.ToLower(url), "javascript:") {
return ""
}
if strings.HasPrefix(strings.ToLower(url), "data:") {
return ""
}
return url
}

View File

@@ -0,0 +1,273 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestSanitizeXSS(t *testing.T) {
sanitizer := DefaultSanitizer()
tests := []struct {
name string
input string
contains string // String that should NOT be in output
}{
{
name: "Script tag",
input: "<script>alert(1)</script>",
contains: "<script>",
},
{
name: "JavaScript protocol",
input: "javascript:alert(1)",
contains: "javascript:",
},
{
name: "Event handler",
input: "<img onerror='alert(1)'>",
contains: "onerror=",
},
{
name: "Iframe",
input: "<iframe src='evil.com'></iframe>",
contains: "<iframe",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizer.Sanitize(tt.input)
if result == tt.input {
t.Errorf("Sanitize() did not modify input: %q", tt.input)
}
})
}
}
func TestSanitizeNullBytes(t *testing.T) {
sanitizer := DefaultSanitizer()
input := "hello\x00world"
result := sanitizer.Sanitize(input)
if result == input {
t.Error("Null bytes should be removed")
}
if len(result) >= len(input) {
t.Errorf("Result length should be less than input: got %d, input %d", len(result), len(input))
}
}
func TestSanitizeControlCharacters(t *testing.T) {
sanitizer := DefaultSanitizer()
// Include various control characters
input := "hello\x01\x02world\x1F"
result := sanitizer.Sanitize(input)
if result == input {
t.Error("Control characters should be removed")
}
// Newlines, tabs, carriage returns should be preserved
input2 := "hello\nworld\t\r"
result2 := sanitizer.Sanitize(input2)
if result2 != input2 {
t.Errorf("Safe control characters should be preserved: got %q, want %q", result2, input2)
}
}
func TestSanitizeMap(t *testing.T) {
sanitizer := DefaultSanitizer()
input := map[string]interface{}{
"name": "<script>alert(1)</script>John",
"email": "test@example.com",
"nested": map[string]interface{}{
"bio": "<iframe src='evil.com'>Bio</iframe>",
},
}
result := sanitizer.SanitizeMap(input)
// Check that script tag was removed/escaped
name, ok := result["name"].(string)
if !ok || name == input["name"] {
t.Error("Name should be sanitized")
}
// Check nested map
nested, ok := result["nested"].(map[string]interface{})
if !ok {
t.Fatal("Nested should still be a map")
}
bio, ok := nested["bio"].(string)
if !ok || bio == input["nested"].(map[string]interface{})["bio"] {
t.Error("Nested bio should be sanitized")
}
}
func TestSanitizeMiddleware(t *testing.T) {
sanitizer := DefaultSanitizer()
handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check that query param was sanitized
param := r.URL.Query().Get("q")
if param == "<script>alert(1)</script>" {
t.Error("Query param should be sanitized")
}
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("GET", "/test?q=<script>alert(1)</script>", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK)
}
}
func TestSanitizeFilename(t *testing.T) {
tests := []struct {
name string
input string
contains string // String that should NOT be in output
}{
{
name: "Path traversal",
input: "../../../etc/passwd",
contains: "..",
},
{
name: "Absolute path",
input: "/etc/passwd",
contains: "/",
},
{
name: "Windows path",
input: "..\\..\\windows\\system32",
contains: "\\",
},
{
name: "Null byte",
input: "file\x00.txt",
contains: "\x00",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
if result == tt.input {
t.Errorf("SanitizeFilename() did not modify input: %q", tt.input)
}
})
}
}
func TestSanitizeEmail(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "Uppercase",
input: "TEST@EXAMPLE.COM",
expected: "test@example.com",
},
{
name: "Whitespace",
input: " test@example.com ",
expected: "test@example.com",
},
{
name: "Null bytes",
input: "test\x00@example.com",
expected: "test@example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeEmail(tt.input)
if result != tt.expected {
t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected)
}
})
}
}
func TestSanitizeURL(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "JavaScript protocol",
input: "javascript:alert(1)",
expected: "",
},
{
name: "Data protocol",
input: "data:text/html,<script>alert(1)</script>",
expected: "",
},
{
name: "Valid HTTP URL",
input: "https://example.com",
expected: "https://example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeURL(tt.input)
if result != tt.expected {
t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected)
}
})
}
}
func TestStrictSanitizer(t *testing.T) {
sanitizer := StrictSanitizer()
input := "<b>Bold text</b> with <script>alert(1)</script>"
result := sanitizer.Sanitize(input)
// Should strip ALL HTML tags
if result == input {
t.Error("Strict sanitizer should modify input")
}
// Should not contain any HTML tags
if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') {
t.Error("Result should not contain HTML tags")
}
}
func TestMaxStringLength(t *testing.T) {
sanitizer := &Sanitizer{
MaxStringLength: 10,
}
input := "This is a very long string that exceeds the maximum length"
result := sanitizer.Sanitize(input)
if len(result) != 10 {
t.Errorf("Result length = %d, want 10", len(result))
}
if result != input[:10] {
t.Errorf("Result = %q, want %q", result, input[:10])
}
}

View File

@@ -0,0 +1,70 @@
package middleware
import (
"fmt"
"net/http"
)
const (
// DefaultMaxRequestSize is the default maximum request body size (10MB)
DefaultMaxRequestSize = 10 * 1024 * 1024 // 10MB
// MaxRequestSizeHeader is the header name for max request size
MaxRequestSizeHeader = "X-Max-Request-Size"
)
// RequestSizeLimiter limits the size of request bodies
type RequestSizeLimiter struct {
maxSize int64
}
// NewRequestSizeLimiter creates a new request size limiter
// maxSize is in bytes. If 0, uses DefaultMaxRequestSize (10MB)
func NewRequestSizeLimiter(maxSize int64) *RequestSizeLimiter {
if maxSize <= 0 {
maxSize = DefaultMaxRequestSize
}
return &RequestSizeLimiter{
maxSize: maxSize,
}
}
// Middleware returns an HTTP middleware that enforces request size limits
func (rsl *RequestSizeLimiter) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Set max bytes reader on the request body
r.Body = http.MaxBytesReader(w, r.Body, rsl.maxSize)
// Add informational header
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", rsl.maxSize))
next.ServeHTTP(w, r)
})
}
// MiddlewareWithCustomSize returns middleware with a custom size limit function
// This allows different size limits based on the request
func (rsl *RequestSizeLimiter) MiddlewareWithCustomSize(sizeFunc func(*http.Request) int64) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
maxSize := sizeFunc(r)
if maxSize <= 0 {
maxSize = rsl.maxSize
}
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", maxSize))
next.ServeHTTP(w, r)
})
}
}
// Common size limits
const (
Size1MB = 1 * 1024 * 1024
Size5MB = 5 * 1024 * 1024
Size10MB = 10 * 1024 * 1024
Size50MB = 50 * 1024 * 1024
Size100MB = 100 * 1024 * 1024
)

View File

@@ -0,0 +1,126 @@
package middleware
import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"
)
func TestRequestSizeLimiter(t *testing.T) {
// 1KB limit
limiter := NewRequestSizeLimiter(1024)
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Try to read body
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}))
// Small request (should succeed)
t.Run("SmallRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 512)) // 512 bytes
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Check header
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" {
t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024")
}
})
// Large request (should fail)
t.Run("LargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048)) // 2KB
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
}
func TestRequestSizeLimiterDefault(t *testing.T) {
// Default limiter (10MB)
limiter := NewRequestSizeLimiter(0)
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024)))
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK)
}
// Check default size
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" {
t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760")
}
}
func TestRequestSizeLimiterWithCustomSize(t *testing.T) {
limiter := NewRequestSizeLimiter(1024)
// Premium users get 10MB, regular users get 1KB
sizeFunc := func(r *http.Request) int64 {
if r.Header.Get("X-User-Tier") == "premium" {
return Size10MB
}
return 1024
}
handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
return
}
w.WriteHeader(http.StatusOK)
}))
// Regular user with large request (should fail)
t.Run("RegularUserLargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048))
req := httptest.NewRequest("POST", "/test", body)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusRequestEntityTooLarge {
t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
}
})
// Premium user with large request (should succeed)
t.Run("PremiumUserLargeRequest", func(t *testing.T) {
body := bytes.NewReader(make([]byte, 2048))
req := httptest.NewRequest("POST", "/test", body)
req.Header.Set("X-User-Tier", "premium")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK)
}
})
}

View File

@@ -28,8 +28,14 @@ func NewModelRegistry() *DefaultModelRegistry {
} }
} }
func GetDefaultRegistry() *DefaultModelRegistry {
return defaultRegistry
}
func SetDefaultRegistry(registry *DefaultModelRegistry) { func SetDefaultRegistry(registry *DefaultModelRegistry) {
registriesMutex.Lock() registriesMutex.Lock()
defer registriesMutex.Unlock()
foundAt := -1 foundAt := -1
for idx, r := range registries { for idx, r := range registries {
if r == defaultRegistry { if r == defaultRegistry {
@@ -43,9 +49,6 @@ func SetDefaultRegistry(registry *DefaultModelRegistry) {
} else { } else {
registries = append([]*DefaultModelRegistry{registry}, registries...) registries = append([]*DefaultModelRegistry{registry}, registries...)
} }
defer registriesMutex.Unlock()
} }
// AddRegistry adds a registry to the global list of registries // AddRegistry adds a registry to the global list of registries

321
pkg/openapi/README.md Normal file
View File

@@ -0,0 +1,321 @@
# OpenAPI Generator for ResolveSpec
This package provides automatic OpenAPI 3.0 specification generation for ResolveSpec, RestheadSpec, and FuncSpec API frameworks.
## Features
- **Automatic Schema Generation**: Generates OpenAPI schemas from Go struct models
- **Multiple Framework Support**: Works with RestheadSpec, ResolveSpec, and FuncSpec
- **Dynamic Endpoint Discovery**: Automatically discovers all registered models and generates paths
- **Query Parameter Access**: Access spec via `?openapi` on any endpoint or via `/openapi`
- **Comprehensive Documentation**: Includes all request/response schemas, parameters, and security schemes
## Quick Start
### RestheadSpec Example
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/openapi"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/gorilla/mux"
)
func main() {
// 1. Create handler
handler := restheadspec.NewHandlerWithGORM(db)
// 2. Register models
handler.registry.RegisterModel("public.users", User{})
handler.registry.RegisterModel("public.products", Product{})
// 3. Configure OpenAPI generator
handler.SetOpenAPIGenerator(func() (string, error) {
generator := openapi.NewGenerator(openapi.GeneratorConfig{
Title: "My API",
Description: "API documentation",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: handler.registry.(*modelregistry.DefaultModelRegistry),
IncludeRestheadSpec: true,
IncludeResolveSpec: false,
IncludeFuncSpec: false,
})
return generator.GenerateJSON()
})
// 4. Setup routes (automatically includes /openapi endpoint)
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Start server
http.ListenAndServe(":8080", router)
}
```
### ResolveSpec Example
```go
func main() {
// 1. Create handler
handler := resolvespec.NewHandlerWithGORM(db)
// 2. Register models
handler.RegisterModel("public", "users", User{})
handler.RegisterModel("public", "products", Product{})
// 3. Configure OpenAPI generator
handler.SetOpenAPIGenerator(func() (string, error) {
generator := openapi.NewGenerator(openapi.GeneratorConfig{
Title: "My API",
Version: "1.0.0",
Registry: handler.registry.(*modelregistry.DefaultModelRegistry),
IncludeResolveSpec: true,
})
return generator.GenerateJSON()
})
// 4. Setup routes
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
http.ListenAndServe(":8080", router)
}
```
## Accessing the OpenAPI Specification
Once configured, the OpenAPI spec is available in two ways:
### 1. Global `/openapi` Endpoint
```bash
curl http://localhost:8080/openapi
```
Returns the complete OpenAPI specification for all registered models.
### 2. Query Parameter on Any Endpoint
```bash
# RestheadSpec
curl http://localhost:8080/public/users?openapi
# ResolveSpec
curl http://localhost:8080/resolve/public/users?openapi
```
Returns the same OpenAPI specification as `/openapi`.
## Generated Endpoints
### RestheadSpec
For each registered model (e.g., `public.users`), the following paths are generated:
- `GET /public/users` - List records with header-based filtering
- `POST /public/users` - Create a new record
- `GET /public/users/{id}` - Get a single record
- `PUT /public/users/{id}` - Update a record
- `PATCH /public/users/{id}` - Partially update a record
- `DELETE /public/users/{id}` - Delete a record
- `GET /public/users/metadata` - Get table metadata
- `OPTIONS /public/users` - CORS preflight
### ResolveSpec
For each registered model (e.g., `public.users`), the following paths are generated:
- `POST /resolve/public/users` - Execute operations (read, create, meta)
- `POST /resolve/public/users/{id}` - Execute operations (update, delete)
- `GET /resolve/public/users` - Get metadata
- `OPTIONS /resolve/public/users` - CORS preflight
## Schema Generation
The generator automatically extracts information from your Go struct tags:
```go
type User struct {
ID int `json:"id" gorm:"primaryKey" description:"User ID"`
Name string `json:"name" gorm:"not null" description:"User's full name"`
Email string `json:"email" gorm:"unique" description:"Email address"`
CreatedAt time.Time `json:"created_at" description:"Creation timestamp"`
Roles []string `json:"roles" description:"User roles"`
}
```
This generates an OpenAPI schema with:
- Property names from `json` tags
- Required fields from `gorm:"not null"` and non-pointer types
- Descriptions from `description` tags
- Proper type mappings (int → integer, time.Time → string with format: date-time, etc.)
## RestheadSpec Headers
The generator documents all RestheadSpec HTTP headers:
- `X-Filters` - JSON array of filter conditions
- `X-Columns` - Comma-separated columns to select
- `X-Sort` - JSON array of sort specifications
- `X-Limit` - Maximum records to return
- `X-Offset` - Records to skip
- `X-Preload` - Relations to eager load
- `X-Expand` - Relations to expand (LEFT JOIN)
- `X-Distinct` - Enable DISTINCT queries
- `X-Response-Format` - Response format (detail, simple, syncfusion)
- `X-Clean-JSON` - Remove null/empty fields
- `X-Custom-SQL-Where` - Custom WHERE clause (AND)
- `X-Custom-SQL-Or` - Custom WHERE clause (OR)
## ResolveSpec Request Body
The generator documents the ResolveSpec request body structure:
```json
{
"operation": "read",
"data": {},
"id": 123,
"options": {
"limit": 10,
"offset": 0,
"filters": [
{"column": "status", "operator": "eq", "value": "active"}
],
"sort": [
{"column": "created_at", "direction": "desc"}
]
}
}
```
## Security Schemes
The generator automatically includes common security schemes:
- **BearerAuth**: JWT Bearer token authentication
- **SessionToken**: Session token in Authorization header
- **CookieAuth**: Cookie-based session authentication
- **HeaderAuth**: Header-based user authentication (X-User-ID)
## FuncSpec Custom Endpoints
For FuncSpec, you can manually register custom SQL endpoints:
```go
funcSpecEndpoints := map[string]openapi.FuncSpecEndpoint{
"/api/reports/sales": {
Path: "/api/reports/sales",
Method: "GET",
Summary: "Get sales report",
Description: "Returns sales data for specified date range",
SQLQuery: "SELECT * FROM sales WHERE date BETWEEN [start_date] AND [end_date]",
Parameters: []string{"start_date", "end_date"},
},
}
generator := openapi.NewGenerator(openapi.GeneratorConfig{
// ... other config
IncludeFuncSpec: true,
FuncSpecEndpoints: funcSpecEndpoints,
})
```
## Combining Multiple Frameworks
You can generate a unified OpenAPI spec that includes multiple frameworks:
```go
generator := openapi.NewGenerator(openapi.GeneratorConfig{
Title: "Unified API",
Version: "1.0.0",
Registry: sharedRegistry,
IncludeRestheadSpec: true,
IncludeResolveSpec: true,
IncludeFuncSpec: true,
FuncSpecEndpoints: funcSpecEndpoints,
})
```
This will generate a complete spec with all endpoints from all frameworks.
## Advanced Customization
You can customize the generated spec further:
```go
handler.SetOpenAPIGenerator(func() (string, error) {
generator := openapi.NewGenerator(config)
// Generate initial spec
spec, err := generator.Generate()
if err != nil {
return "", err
}
// Add contact information
spec.Info.Contact = &openapi.Contact{
Name: "API Support",
Email: "support@example.com",
URL: "https://example.com/support",
}
// Add additional servers
spec.Servers = append(spec.Servers, openapi.Server{
URL: "https://staging.example.com",
Description: "Staging Server",
})
// Convert back to JSON
data, _ := json.MarshalIndent(spec, "", " ")
return string(data), nil
})
```
## Using with Swagger UI
You can serve the generated OpenAPI spec with Swagger UI:
1. Get the spec from `/openapi`
2. Load it in Swagger UI at `https://petstore.swagger.io/`
3. Or self-host Swagger UI and point it to your `/openapi` endpoint
Example with self-hosted Swagger UI:
```go
// Serve Swagger UI static files
router.PathPrefix("/swagger/").Handler(
http.StripPrefix("/swagger/", http.FileServer(http.Dir("./swagger-ui"))),
)
// Configure Swagger UI to use /openapi
```
## Testing
You can test the OpenAPI endpoint:
```bash
# Get the full spec
curl http://localhost:8080/openapi | jq
# Validate with openapi-generator
openapi-generator validate -i http://localhost:8080/openapi
# Generate client SDKs
openapi-generator generate -i http://localhost:8080/openapi -g typescript-fetch -o ./client
```
## Complete Example
See `example.go` in this package for complete, runnable examples including:
- Basic RestheadSpec setup
- Basic ResolveSpec setup
- Combining both frameworks
- Adding FuncSpec endpoints
- Advanced customization
## License
Part of the ResolveSpec project.

236
pkg/openapi/example.go Normal file
View File

@@ -0,0 +1,236 @@
package openapi
import (
"github.com/gorilla/mux"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// ExampleRestheadSpec shows how to configure OpenAPI generation for RestheadSpec
func ExampleRestheadSpec(db *gorm.DB) {
// 1. Create registry and register models
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", User{})
// registry.RegisterModel("public.products", Product{})
// 2. Create handler with custom registry
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
// gormAdapter := database.NewGormAdapter(db)
// handler := restheadspec.NewHandler(gormAdapter, registry)
// Or use the convenience function (creates its own registry):
handler := restheadspec.NewHandlerWithGORM(db)
// 3. Configure OpenAPI generator
handler.SetOpenAPIGenerator(func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My API",
Description: "API documentation for my application",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: registry,
IncludeRestheadSpec: true,
IncludeResolveSpec: false,
IncludeFuncSpec: false,
})
return generator.GenerateJSON()
})
// 4. Setup routes (includes /openapi endpoint)
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, handler, nil)
// Now the following endpoints are available:
// GET /openapi - Full OpenAPI spec
// GET /public/users?openapi - OpenAPI spec
// GET /public/products?openapi - OpenAPI spec
// etc.
}
// ExampleResolveSpec shows how to configure OpenAPI generation for ResolveSpec
func ExampleResolveSpec(db *gorm.DB) {
// 1. Create registry and register models
registry := modelregistry.NewModelRegistry()
// registry.RegisterModel("public.users", User{})
// registry.RegisterModel("public.products", Product{})
// 2. Create handler with custom registry
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
// gormAdapter := database.NewGormAdapter(db)
// handler := resolvespec.NewHandler(gormAdapter, registry)
// Or use the convenience function (creates its own registry):
handler := resolvespec.NewHandlerWithGORM(db)
// Note: handler.RegisterModel("schema", "entity", model) can be used
// 3. Configure OpenAPI generator
handler.SetOpenAPIGenerator(func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My API",
Description: "API documentation for my application",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: registry,
IncludeRestheadSpec: false,
IncludeResolveSpec: true,
IncludeFuncSpec: false,
})
return generator.GenerateJSON()
})
// 4. Setup routes (includes /openapi endpoint)
router := mux.NewRouter()
resolvespec.SetupMuxRoutes(router, handler, nil)
// Now the following endpoints are available:
// GET /openapi - Full OpenAPI spec
// POST /resolve/public/users?openapi - OpenAPI spec
// POST /resolve/public/products?openapi - OpenAPI spec
// etc.
}
// ExampleBothSpecs shows how to combine both RestheadSpec and ResolveSpec
func ExampleBothSpecs(db *gorm.DB) {
// Create shared registry
sharedRegistry := modelregistry.NewModelRegistry()
// Register models once
// sharedRegistry.RegisterModel("public.users", User{})
// sharedRegistry.RegisterModel("public.products", Product{})
// Create handlers - they will have separate registries initially
restheadHandler := restheadspec.NewHandlerWithGORM(db)
resolveHandler := resolvespec.NewHandlerWithGORM(db)
// Note: If you want to use a shared registry, create handlers manually:
// import "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
// gormAdapter := database.NewGormAdapter(db)
// restheadHandler := restheadspec.NewHandler(gormAdapter, sharedRegistry)
// resolveHandler := resolvespec.NewHandler(gormAdapter, sharedRegistry)
// Configure OpenAPI generator for both
generatorFunc := func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My Unified API",
Description: "Complete API documentation with both RestheadSpec and ResolveSpec endpoints",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: sharedRegistry,
IncludeRestheadSpec: true,
IncludeResolveSpec: true,
IncludeFuncSpec: false,
})
return generator.GenerateJSON()
}
restheadHandler.SetOpenAPIGenerator(generatorFunc)
resolveHandler.SetOpenAPIGenerator(generatorFunc)
// Setup routes
router := mux.NewRouter()
restheadspec.SetupMuxRoutes(router, restheadHandler, nil)
// Add ResolveSpec routes under /resolve prefix
resolveRouter := router.PathPrefix("/resolve").Subrouter()
resolvespec.SetupMuxRoutes(resolveRouter, resolveHandler, nil)
// Now you have both styles of API available:
// GET /openapi - Full OpenAPI spec (both styles)
// GET /public/users - RestheadSpec list endpoint
// POST /resolve/public/users - ResolveSpec operation endpoint
// GET /public/users?openapi - OpenAPI spec
// POST /resolve/public/users?openapi - OpenAPI spec
}
// ExampleWithFuncSpec shows how to add FuncSpec endpoints to OpenAPI
func ExampleWithFuncSpec() {
// FuncSpec endpoints need to be registered manually since they don't use model registry
generatorFunc := func() (string, error) {
funcSpecEndpoints := map[string]FuncSpecEndpoint{
"/api/reports/sales": {
Path: "/api/reports/sales",
Method: "GET",
Summary: "Get sales report",
Description: "Returns sales data for the specified date range",
SQLQuery: "SELECT * FROM sales WHERE date BETWEEN [start_date] AND [end_date]",
Parameters: []string{"start_date", "end_date"},
},
"/api/analytics/users": {
Path: "/api/analytics/users",
Method: "GET",
Summary: "Get user analytics",
Description: "Returns user activity analytics",
SQLQuery: "SELECT * FROM user_analytics WHERE user_id = [user_id]",
Parameters: []string{"user_id"},
},
}
generator := NewGenerator(GeneratorConfig{
Title: "My API with Custom Queries",
Description: "API with FuncSpec custom SQL endpoints",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: modelregistry.NewModelRegistry(),
IncludeRestheadSpec: false,
IncludeResolveSpec: false,
IncludeFuncSpec: true,
FuncSpecEndpoints: funcSpecEndpoints,
})
return generator.GenerateJSON()
}
// Use this generator function with your handlers
_ = generatorFunc
}
// ExampleCustomization shows advanced customization options
func ExampleCustomization() {
// Create registry and register models with descriptions using struct tags
registry := modelregistry.NewModelRegistry()
// type User struct {
// ID int `json:"id" gorm:"primaryKey" description:"Unique user identifier"`
// Name string `json:"name" description:"User's full name"`
// Email string `json:"email" gorm:"unique" description:"User's email address"`
// }
// registry.RegisterModel("public.users", User{})
// Advanced configuration - create generator function
generatorFunc := func() (string, error) {
generator := NewGenerator(GeneratorConfig{
Title: "My Advanced API",
Description: "Comprehensive API documentation with custom configuration",
Version: "2.1.0",
BaseURL: "https://api.myapp.com",
Registry: registry,
IncludeRestheadSpec: true,
IncludeResolveSpec: true,
IncludeFuncSpec: false,
})
// Generate the spec
// spec, err := generator.Generate()
// if err != nil {
// return "", err
// }
// Customize the spec further if needed
// spec.Info.Contact = &Contact{
// Name: "API Support",
// Email: "support@myapp.com",
// URL: "https://myapp.com/support",
// }
// Add additional servers
// spec.Servers = append(spec.Servers, Server{
// URL: "https://staging-api.myapp.com",
// Description: "Staging Server",
// })
// Convert back to JSON - or use GenerateJSON() for simple cases
return generator.GenerateJSON()
}
// Use this generator function with your handlers
_ = generatorFunc
}

513
pkg/openapi/generator.go Normal file
View File

@@ -0,0 +1,513 @@
package openapi
import (
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// OpenAPISpec represents the OpenAPI 3.0 specification structure
type OpenAPISpec struct {
OpenAPI string `json:"openapi"`
Info Info `json:"info"`
Servers []Server `json:"servers,omitempty"`
Paths map[string]PathItem `json:"paths"`
Components Components `json:"components"`
Security []map[string][]string `json:"security,omitempty"`
}
type Info struct {
Title string `json:"title"`
Description string `json:"description,omitempty"`
Version string `json:"version"`
Contact *Contact `json:"contact,omitempty"`
}
type Contact struct {
Name string `json:"name,omitempty"`
URL string `json:"url,omitempty"`
Email string `json:"email,omitempty"`
}
type Server struct {
URL string `json:"url"`
Description string `json:"description,omitempty"`
}
type PathItem struct {
Get *Operation `json:"get,omitempty"`
Post *Operation `json:"post,omitempty"`
Put *Operation `json:"put,omitempty"`
Patch *Operation `json:"patch,omitempty"`
Delete *Operation `json:"delete,omitempty"`
Options *Operation `json:"options,omitempty"`
}
type Operation struct {
Summary string `json:"summary,omitempty"`
Description string `json:"description,omitempty"`
OperationID string `json:"operationId,omitempty"`
Tags []string `json:"tags,omitempty"`
Parameters []Parameter `json:"parameters,omitempty"`
RequestBody *RequestBody `json:"requestBody,omitempty"`
Responses map[string]Response `json:"responses"`
Security []map[string][]string `json:"security,omitempty"`
}
type Parameter struct {
Name string `json:"name"`
In string `json:"in"` // "query", "header", "path", "cookie"
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
Schema *Schema `json:"schema,omitempty"`
Example interface{} `json:"example,omitempty"`
}
type RequestBody struct {
Description string `json:"description,omitempty"`
Required bool `json:"required,omitempty"`
Content map[string]MediaType `json:"content"`
}
type MediaType struct {
Schema *Schema `json:"schema,omitempty"`
Example interface{} `json:"example,omitempty"`
}
type Response struct {
Description string `json:"description"`
Content map[string]MediaType `json:"content,omitempty"`
}
type Components struct {
Schemas map[string]Schema `json:"schemas,omitempty"`
SecuritySchemes map[string]SecurityScheme `json:"securitySchemes,omitempty"`
}
type Schema struct {
Type string `json:"type,omitempty"`
Format string `json:"format,omitempty"`
Description string `json:"description,omitempty"`
Properties map[string]*Schema `json:"properties,omitempty"`
Items *Schema `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Ref string `json:"$ref,omitempty"`
Enum []interface{} `json:"enum,omitempty"`
Example interface{} `json:"example,omitempty"`
AdditionalProperties interface{} `json:"additionalProperties,omitempty"`
OneOf []*Schema `json:"oneOf,omitempty"`
AnyOf []*Schema `json:"anyOf,omitempty"`
}
type SecurityScheme struct {
Type string `json:"type"` // "apiKey", "http", "oauth2", "openIdConnect"
Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"` // For apiKey
In string `json:"in,omitempty"` // For apiKey: "query", "header", "cookie"
Scheme string `json:"scheme,omitempty"` // For http: "basic", "bearer"
BearerFormat string `json:"bearerFormat,omitempty"` // For http bearer
}
// GeneratorConfig holds configuration for OpenAPI spec generation
type GeneratorConfig struct {
Title string
Description string
Version string
BaseURL string
Registry *modelregistry.DefaultModelRegistry
IncludeRestheadSpec bool
IncludeResolveSpec bool
IncludeFuncSpec bool
FuncSpecEndpoints map[string]FuncSpecEndpoint // path -> endpoint info
}
// FuncSpecEndpoint represents a FuncSpec endpoint for OpenAPI generation
type FuncSpecEndpoint struct {
Path string
Method string
Summary string
Description string
SQLQuery string
Parameters []string // Parameter names extracted from SQL
}
// Generator creates OpenAPI specifications
type Generator struct {
config GeneratorConfig
}
// NewGenerator creates a new OpenAPI generator
func NewGenerator(config GeneratorConfig) *Generator {
if config.Title == "" {
config.Title = "ResolveSpec API"
}
if config.Version == "" {
config.Version = "1.0.0"
}
return &Generator{config: config}
}
// Generate creates the complete OpenAPI specification
func (g *Generator) Generate() (*OpenAPISpec, error) {
spec := &OpenAPISpec{
OpenAPI: "3.0.0",
Info: Info{
Title: g.config.Title,
Description: g.config.Description,
Version: g.config.Version,
},
Paths: make(map[string]PathItem),
Components: Components{
Schemas: make(map[string]Schema),
SecuritySchemes: g.generateSecuritySchemes(),
},
}
if g.config.BaseURL != "" {
spec.Servers = []Server{
{URL: g.config.BaseURL, Description: "API Server"},
}
}
// Add common schemas
g.addCommonSchemas(spec)
// Generate paths and schemas from registered models
if err := g.generateFromModels(spec); err != nil {
return nil, err
}
return spec, nil
}
// GenerateJSON generates OpenAPI spec as JSON string
func (g *Generator) GenerateJSON() (string, error) {
spec, err := g.Generate()
if err != nil {
return "", err
}
data, err := json.MarshalIndent(spec, "", " ")
if err != nil {
return "", fmt.Errorf("failed to marshal spec: %w", err)
}
return string(data), nil
}
// generateSecuritySchemes creates security scheme definitions
func (g *Generator) generateSecuritySchemes() map[string]SecurityScheme {
return map[string]SecurityScheme{
"BearerAuth": {
Type: "http",
Scheme: "bearer",
BearerFormat: "JWT",
Description: "JWT Bearer token authentication",
},
"SessionToken": {
Type: "apiKey",
In: "header",
Name: "Authorization",
Description: "Session token authentication",
},
"CookieAuth": {
Type: "apiKey",
In: "cookie",
Name: "session_token",
Description: "Cookie-based session authentication",
},
"HeaderAuth": {
Type: "apiKey",
In: "header",
Name: "X-User-ID",
Description: "Header-based user authentication",
},
}
}
// addCommonSchemas adds common reusable schemas
func (g *Generator) addCommonSchemas(spec *OpenAPISpec) {
// Response wrapper schema
spec.Components.Schemas["Response"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean", Description: "Indicates if the operation was successful"},
"data": {Description: "The response data"},
"metadata": {Ref: "#/components/schemas/Metadata"},
"error": {Ref: "#/components/schemas/APIError"},
},
}
// Metadata schema
spec.Components.Schemas["Metadata"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"total": {Type: "integer", Description: "Total number of records"},
"count": {Type: "integer", Description: "Number of records in this response"},
"filtered": {Type: "integer", Description: "Number of records after filtering"},
"limit": {Type: "integer", Description: "Limit applied"},
"offset": {Type: "integer", Description: "Offset applied"},
"rowNumber": {Type: "integer", Description: "Row number for cursor pagination"},
},
}
// APIError schema
spec.Components.Schemas["APIError"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"code": {Type: "string", Description: "Error code"},
"message": {Type: "string", Description: "Error message"},
"details": {Type: "string", Description: "Detailed error information"},
},
}
// RequestOptions schema
spec.Components.Schemas["RequestOptions"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"preload": {
Type: "array",
Description: "Relations to eager load",
Items: &Schema{Ref: "#/components/schemas/PreloadOption"},
},
"columns": {
Type: "array",
Description: "Columns to select",
Items: &Schema{Type: "string"},
},
"omitColumns": {
Type: "array",
Description: "Columns to exclude",
Items: &Schema{Type: "string"},
},
"filters": {
Type: "array",
Description: "Filter conditions",
Items: &Schema{Ref: "#/components/schemas/FilterOption"},
},
"sort": {
Type: "array",
Description: "Sort specifications",
Items: &Schema{Ref: "#/components/schemas/SortOption"},
},
"limit": {Type: "integer", Description: "Maximum number of records"},
"offset": {Type: "integer", Description: "Number of records to skip"},
},
}
// FilterOption schema
spec.Components.Schemas["FilterOption"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"column": {Type: "string", Description: "Column name"},
"operator": {Type: "string", Description: "Comparison operator", Enum: []interface{}{"eq", "neq", "gt", "lt", "gte", "lte", "like", "ilike", "in", "not_in", "between", "is_null", "is_not_null"}},
"value": {Description: "Filter value"},
"logicOperator": {Type: "string", Description: "Logic operator", Enum: []interface{}{"AND", "OR"}},
},
}
// SortOption schema
spec.Components.Schemas["SortOption"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"column": {Type: "string", Description: "Column name"},
"direction": {Type: "string", Description: "Sort direction", Enum: []interface{}{"asc", "desc"}},
},
}
// PreloadOption schema
spec.Components.Schemas["PreloadOption"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"relation": {Type: "string", Description: "Relation name"},
"columns": {
Type: "array",
Description: "Columns to select from related table",
Items: &Schema{Type: "string"},
},
},
}
// ResolveSpec RequestBody schema
spec.Components.Schemas["ResolveSpecRequest"] = Schema{
Type: "object",
Properties: map[string]*Schema{
"operation": {Type: "string", Description: "Operation type", Enum: []interface{}{"read", "create", "update", "delete", "meta"}},
"data": {Description: "Payload data (object or array)"},
"id": {Type: "integer", Description: "Record ID for single operations"},
"options": {Ref: "#/components/schemas/RequestOptions"},
},
}
}
// generateFromModels generates paths and schemas from registered models
func (g *Generator) generateFromModels(spec *OpenAPISpec) error {
if g.config.Registry == nil {
return fmt.Errorf("model registry is required")
}
models := g.config.Registry.GetAllModels()
for name, model := range models {
// Parse schema.entity from model name
schema, entity := parseModelName(name)
// Generate schema for this model
modelSchema := g.generateModelSchema(model)
schemaName := formatSchemaName(schema, entity)
spec.Components.Schemas[schemaName] = modelSchema
// Generate paths for different frameworks
if g.config.IncludeRestheadSpec {
g.generateRestheadSpecPaths(spec, schema, entity, schemaName)
}
if g.config.IncludeResolveSpec {
g.generateResolveSpecPaths(spec, schema, entity, schemaName)
}
}
// Generate FuncSpec paths if configured
if g.config.IncludeFuncSpec && len(g.config.FuncSpecEndpoints) > 0 {
g.generateFuncSpecPaths(spec)
}
return nil
}
// generateModelSchema creates an OpenAPI schema from a Go struct
func (g *Generator) generateModelSchema(model interface{}) Schema {
schema := Schema{
Type: "object",
Properties: make(map[string]*Schema),
Required: []string{},
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType.Kind() != reflect.Struct {
return schema
}
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
// Get JSON tag name
jsonTag := field.Tag.Get("json")
if jsonTag == "-" {
continue
}
fieldName := strings.Split(jsonTag, ",")[0]
if fieldName == "" {
fieldName = field.Name
}
// Generate property schema
propSchema := g.generatePropertySchema(field)
schema.Properties[fieldName] = propSchema
// Check if field is required (not a pointer and no omitempty)
if field.Type.Kind() != reflect.Ptr && !strings.Contains(jsonTag, "omitempty") {
schema.Required = append(schema.Required, fieldName)
}
}
return schema
}
// generatePropertySchema creates a schema for a struct field
func (g *Generator) generatePropertySchema(field reflect.StructField) *Schema {
schema := &Schema{}
fieldType := field.Type
if fieldType.Kind() == reflect.Ptr {
fieldType = fieldType.Elem()
}
// Get description from tag
if desc := field.Tag.Get("description"); desc != "" {
schema.Description = desc
}
switch fieldType.Kind() {
case reflect.String:
schema.Type = "string"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
schema.Type = "integer"
case reflect.Float32, reflect.Float64:
schema.Type = "number"
case reflect.Bool:
schema.Type = "boolean"
case reflect.Slice, reflect.Array:
schema.Type = "array"
elemType := fieldType.Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
if elemType.Kind() == reflect.Struct {
// Complex type - would need recursive handling
schema.Items = &Schema{Type: "object"}
} else {
schema.Items = g.generatePropertySchema(reflect.StructField{Type: elemType})
}
case reflect.Struct:
// Check for time.Time
if fieldType.String() == "time.Time" {
schema.Type = "string"
schema.Format = "date-time"
} else {
schema.Type = "object"
}
default:
schema.Type = "string"
}
// Check for custom format from gorm/bun tags
if gormTag := field.Tag.Get("gorm"); gormTag != "" {
if strings.Contains(gormTag, "type:uuid") {
schema.Format = "uuid"
}
}
return schema
}
// parseModelName splits "schema.entity" or returns "public" and entity
func parseModelName(name string) (schema, entity string) {
parts := strings.Split(name, ".")
if len(parts) == 2 {
return parts[0], parts[1]
}
return "public", name
}
// formatSchemaName creates a component schema name
func formatSchemaName(schema, entity string) string {
if schema == "public" {
return toTitleCase(entity)
}
return toTitleCase(schema) + toTitleCase(entity)
}
// toTitleCase converts a string to title case (first letter uppercase)
func toTitleCase(s string) string {
if s == "" {
return ""
}
if len(s) == 1 {
return strings.ToUpper(s)
}
return strings.ToUpper(s[:1]) + s[1:]
}

View File

@@ -0,0 +1,714 @@
package openapi
import (
"encoding/json"
"strings"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Test models
type TestUser struct {
ID int `json:"id" gorm:"primaryKey" description:"User ID"`
Name string `json:"name" gorm:"not null" description:"User's full name"`
Email string `json:"email" gorm:"unique" description:"Email address"`
Age int `json:"age" description:"User age"`
IsActive bool `json:"is_active" description:"Active status"`
CreatedAt time.Time `json:"created_at" description:"Creation timestamp"`
UpdatedAt *time.Time `json:"updated_at,omitempty" description:"Last update timestamp"`
Roles []string `json:"roles,omitempty" description:"User roles"`
}
type TestProduct struct {
ID int `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"not null"`
Description string `json:"description"`
Price float64 `json:"price"`
InStock bool `json:"in_stock"`
}
type TestOrder struct {
ID int `json:"id" gorm:"primaryKey"`
UserID int `json:"user_id" gorm:"not null"`
ProductID int `json:"product_id" gorm:"not null"`
Quantity int `json:"quantity"`
TotalPrice float64 `json:"total_price"`
}
func TestNewGenerator(t *testing.T) {
registry := modelregistry.NewModelRegistry()
tests := []struct {
name string
config GeneratorConfig
want string // expected title
}{
{
name: "with all fields",
config: GeneratorConfig{
Title: "Test API",
Description: "Test Description",
Version: "1.0.0",
BaseURL: "http://localhost:8080",
Registry: registry,
},
want: "Test API",
},
{
name: "with defaults",
config: GeneratorConfig{
Registry: registry,
},
want: "ResolveSpec API",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gen := NewGenerator(tt.config)
if gen == nil {
t.Fatal("NewGenerator returned nil")
}
if gen.config.Title != tt.want {
t.Errorf("Title = %v, want %v", gen.config.Title, tt.want)
}
})
}
}
func TestGenerateBasicSpec(t *testing.T) {
registry := modelregistry.NewModelRegistry()
err := registry.RegisterModel("public.users", TestUser{})
if err != nil {
t.Fatalf("Failed to register model: %v", err)
}
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeRestheadSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test basic spec structure
if spec.OpenAPI != "3.0.0" {
t.Errorf("OpenAPI version = %v, want 3.0.0", spec.OpenAPI)
}
if spec.Info.Title != "Test API" {
t.Errorf("Title = %v, want Test API", spec.Info.Title)
}
if spec.Info.Version != "1.0.0" {
t.Errorf("Version = %v, want 1.0.0", spec.Info.Version)
}
// Test that common schemas are added
if spec.Components.Schemas["Response"].Type != "object" {
t.Error("Response schema not found or invalid")
}
if spec.Components.Schemas["Metadata"].Type != "object" {
t.Error("Metadata schema not found or invalid")
}
// Test that model schema is added
if _, exists := spec.Components.Schemas["Users"]; !exists {
t.Error("Users schema not found")
}
// Test that security schemes are added
if len(spec.Components.SecuritySchemes) == 0 {
t.Error("Security schemes not added")
}
}
func TestGenerateModelSchema(t *testing.T) {
registry := modelregistry.NewModelRegistry()
gen := NewGenerator(GeneratorConfig{Registry: registry})
schema := gen.generateModelSchema(TestUser{})
// Test basic properties
if schema.Type != "object" {
t.Errorf("Schema type = %v, want object", schema.Type)
}
// Test that properties are generated
expectedProps := []string{"id", "name", "email", "age", "is_active", "created_at", "updated_at", "roles"}
for _, prop := range expectedProps {
if _, exists := schema.Properties[prop]; !exists {
t.Errorf("Property %s not found in schema", prop)
}
}
// Test property types
if schema.Properties["id"].Type != "integer" {
t.Errorf("id type = %v, want integer", schema.Properties["id"].Type)
}
if schema.Properties["name"].Type != "string" {
t.Errorf("name type = %v, want string", schema.Properties["name"].Type)
}
if schema.Properties["is_active"].Type != "boolean" {
t.Errorf("is_active type = %v, want boolean", schema.Properties["is_active"].Type)
}
// Test array type
if schema.Properties["roles"].Type != "array" {
t.Errorf("roles type = %v, want array", schema.Properties["roles"].Type)
}
if schema.Properties["roles"].Items.Type != "string" {
t.Errorf("roles items type = %v, want string", schema.Properties["roles"].Items.Type)
}
// Test time.Time format
if schema.Properties["created_at"].Type != "string" {
t.Errorf("created_at type = %v, want string", schema.Properties["created_at"].Type)
}
if schema.Properties["created_at"].Format != "date-time" {
t.Errorf("created_at format = %v, want date-time", schema.Properties["created_at"].Format)
}
// Test required fields (non-pointer, no omitempty)
requiredFields := map[string]bool{}
for _, field := range schema.Required {
requiredFields[field] = true
}
if !requiredFields["id"] {
t.Error("id should be required")
}
if !requiredFields["name"] {
t.Error("name should be required")
}
if requiredFields["updated_at"] {
t.Error("updated_at should not be required (pointer + omitempty)")
}
if requiredFields["roles"] {
t.Error("roles should not be required (omitempty)")
}
// Test descriptions
if schema.Properties["id"].Description != "User ID" {
t.Errorf("id description = %v, want 'User ID'", schema.Properties["id"].Description)
}
}
func TestGenerateRestheadSpecPaths(t *testing.T) {
registry := modelregistry.NewModelRegistry()
err := registry.RegisterModel("public.users", TestUser{})
if err != nil {
t.Fatalf("Failed to register model: %v", err)
}
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeRestheadSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that paths are generated
expectedPaths := []string{
"/public/users",
"/public/users/{id}",
"/public/users/metadata",
}
for _, path := range expectedPaths {
if _, exists := spec.Paths[path]; !exists {
t.Errorf("Path %s not found", path)
}
}
// Test collection endpoint methods
usersPath := spec.Paths["/public/users"]
if usersPath.Get == nil {
t.Error("GET method not found for /public/users")
}
if usersPath.Post == nil {
t.Error("POST method not found for /public/users")
}
if usersPath.Options == nil {
t.Error("OPTIONS method not found for /public/users")
}
// Test single record endpoint methods
userIDPath := spec.Paths["/public/users/{id}"]
if userIDPath.Get == nil {
t.Error("GET method not found for /public/users/{id}")
}
if userIDPath.Put == nil {
t.Error("PUT method not found for /public/users/{id}")
}
if userIDPath.Patch == nil {
t.Error("PATCH method not found for /public/users/{id}")
}
if userIDPath.Delete == nil {
t.Error("DELETE method not found for /public/users/{id}")
}
// Test metadata endpoint
metadataPath := spec.Paths["/public/users/metadata"]
if metadataPath.Get == nil {
t.Error("GET method not found for /public/users/metadata")
}
// Test operation details
getOp := usersPath.Get
if getOp.Summary == "" {
t.Error("GET operation summary is empty")
}
if getOp.OperationID == "" {
t.Error("GET operation ID is empty")
}
if len(getOp.Tags) == 0 {
t.Error("GET operation has no tags")
}
if len(getOp.Parameters) == 0 {
t.Error("GET operation has no parameters")
}
// Test RestheadSpec headers
hasFiltersHeader := false
for _, param := range getOp.Parameters {
if param.Name == "X-Filters" && param.In == "header" {
hasFiltersHeader = true
break
}
}
if !hasFiltersHeader {
t.Error("X-Filters header parameter not found")
}
}
func TestGenerateResolveSpecPaths(t *testing.T) {
registry := modelregistry.NewModelRegistry()
err := registry.RegisterModel("public.products", TestProduct{})
if err != nil {
t.Fatalf("Failed to register model: %v", err)
}
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeResolveSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that paths are generated
expectedPaths := []string{
"/resolve/public/products",
"/resolve/public/products/{id}",
}
for _, path := range expectedPaths {
if _, exists := spec.Paths[path]; !exists {
t.Errorf("Path %s not found", path)
}
}
// Test collection endpoint methods
productsPath := spec.Paths["/resolve/public/products"]
if productsPath.Post == nil {
t.Error("POST method not found for /resolve/public/products")
}
if productsPath.Get == nil {
t.Error("GET method not found for /resolve/public/products")
}
if productsPath.Options == nil {
t.Error("OPTIONS method not found for /resolve/public/products")
}
// Test POST operation has request body
postOp := productsPath.Post
if postOp.RequestBody == nil {
t.Error("POST operation has no request body")
}
if _, exists := postOp.RequestBody.Content["application/json"]; !exists {
t.Error("POST operation request body has no application/json content")
}
// Test request body schema references ResolveSpecRequest
reqBodySchema := postOp.RequestBody.Content["application/json"].Schema
if reqBodySchema.Ref != "#/components/schemas/ResolveSpecRequest" {
t.Errorf("Request body schema ref = %v, want #/components/schemas/ResolveSpecRequest", reqBodySchema.Ref)
}
}
func TestGenerateFuncSpecPaths(t *testing.T) {
registry := modelregistry.NewModelRegistry()
funcSpecEndpoints := map[string]FuncSpecEndpoint{
"/api/reports/sales": {
Path: "/api/reports/sales",
Method: "GET",
Summary: "Get sales report",
Description: "Returns sales data",
Parameters: []string{"start_date", "end_date"},
},
"/api/analytics/users": {
Path: "/api/analytics/users",
Method: "POST",
Summary: "Get user analytics",
Description: "Returns user activity",
Parameters: []string{"user_id"},
},
}
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeFuncSpec: true,
FuncSpecEndpoints: funcSpecEndpoints,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that FuncSpec paths are generated
salesPath := spec.Paths["/api/reports/sales"]
if salesPath.Get == nil {
t.Error("GET method not found for /api/reports/sales")
}
if salesPath.Get.Summary != "Get sales report" {
t.Errorf("GET summary = %v, want 'Get sales report'", salesPath.Get.Summary)
}
if len(salesPath.Get.Parameters) != 2 {
t.Errorf("GET has %d parameters, want 2", len(salesPath.Get.Parameters))
}
analyticsPath := spec.Paths["/api/analytics/users"]
if analyticsPath.Post == nil {
t.Error("POST method not found for /api/analytics/users")
}
if len(analyticsPath.Post.Parameters) != 1 {
t.Errorf("POST has %d parameters, want 1", len(analyticsPath.Post.Parameters))
}
}
func TestGenerateJSON(t *testing.T) {
registry := modelregistry.NewModelRegistry()
err := registry.RegisterModel("public.users", TestUser{})
if err != nil {
t.Fatalf("Failed to register model: %v", err)
}
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeRestheadSpec: true,
}
gen := NewGenerator(config)
jsonStr, err := gen.GenerateJSON()
if err != nil {
t.Fatalf("GenerateJSON failed: %v", err)
}
// Test that it's valid JSON
var spec OpenAPISpec
if err := json.Unmarshal([]byte(jsonStr), &spec); err != nil {
t.Fatalf("Generated JSON is invalid: %v", err)
}
// Test basic structure
if spec.OpenAPI != "3.0.0" {
t.Errorf("OpenAPI version = %v, want 3.0.0", spec.OpenAPI)
}
if spec.Info.Title != "Test API" {
t.Errorf("Title = %v, want Test API", spec.Info.Title)
}
// Test that JSON contains expected fields
if !strings.Contains(jsonStr, `"openapi"`) {
t.Error("JSON doesn't contain 'openapi' field")
}
if !strings.Contains(jsonStr, `"paths"`) {
t.Error("JSON doesn't contain 'paths' field")
}
if !strings.Contains(jsonStr, `"components"`) {
t.Error("JSON doesn't contain 'components' field")
}
}
func TestMultipleModels(t *testing.T) {
registry := modelregistry.NewModelRegistry()
registry.RegisterModel("public.users", TestUser{})
registry.RegisterModel("public.products", TestProduct{})
registry.RegisterModel("public.orders", TestOrder{})
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeRestheadSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that all model schemas are generated
expectedSchemas := []string{"Users", "Products", "Orders"}
for _, schemaName := range expectedSchemas {
if _, exists := spec.Components.Schemas[schemaName]; !exists {
t.Errorf("Schema %s not found", schemaName)
}
}
// Test that all paths are generated
expectedPaths := []string{
"/public/users",
"/public/products",
"/public/orders",
}
for _, path := range expectedPaths {
if _, exists := spec.Paths[path]; !exists {
t.Errorf("Path %s not found", path)
}
}
}
func TestModelNameParsing(t *testing.T) {
tests := []struct {
name string
fullName string
wantSchema string
wantEntity string
}{
{
name: "with schema",
fullName: "public.users",
wantSchema: "public",
wantEntity: "users",
},
{
name: "without schema",
fullName: "users",
wantSchema: "public",
wantEntity: "users",
},
{
name: "custom schema",
fullName: "custom.products",
wantSchema: "custom",
wantEntity: "products",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, entity := parseModelName(tt.fullName)
if schema != tt.wantSchema {
t.Errorf("schema = %v, want %v", schema, tt.wantSchema)
}
if entity != tt.wantEntity {
t.Errorf("entity = %v, want %v", entity, tt.wantEntity)
}
})
}
}
func TestSchemaNameFormatting(t *testing.T) {
tests := []struct {
name string
schema string
entity string
wantName string
}{
{
name: "public schema",
schema: "public",
entity: "users",
wantName: "Users",
},
{
name: "custom schema",
schema: "custom",
entity: "products",
wantName: "CustomProducts",
},
{
name: "multi-word entity",
schema: "public",
entity: "user_profiles",
wantName: "User_profiles",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
name := formatSchemaName(tt.schema, tt.entity)
if name != tt.wantName {
t.Errorf("formatSchemaName() = %v, want %v", name, tt.wantName)
}
})
}
}
func TestToTitleCase(t *testing.T) {
tests := []struct {
input string
want string
}{
{"users", "Users"},
{"products", "Products"},
{"userProfiles", "UserProfiles"},
{"a", "A"},
{"", ""},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
got := toTitleCase(tt.input)
if got != tt.want {
t.Errorf("toTitleCase(%v) = %v, want %v", tt.input, got, tt.want)
}
})
}
}
func TestGenerateWithBaseURL(t *testing.T) {
registry := modelregistry.NewModelRegistry()
registry.RegisterModel("public.users", TestUser{})
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
BaseURL: "https://api.example.com",
Registry: registry,
IncludeRestheadSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that server is added
if len(spec.Servers) == 0 {
t.Fatal("No servers added")
}
if spec.Servers[0].URL != "https://api.example.com" {
t.Errorf("Server URL = %v, want https://api.example.com", spec.Servers[0].URL)
}
if spec.Servers[0].Description != "API Server" {
t.Errorf("Server description = %v, want 'API Server'", spec.Servers[0].Description)
}
}
func TestGenerateCombinedFrameworks(t *testing.T) {
registry := modelregistry.NewModelRegistry()
registry.RegisterModel("public.users", TestUser{})
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
Registry: registry,
IncludeRestheadSpec: true,
IncludeResolveSpec: true,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that both RestheadSpec and ResolveSpec paths are generated
restheadPath := "/public/users"
resolveSpecPath := "/resolve/public/users"
if _, exists := spec.Paths[restheadPath]; !exists {
t.Errorf("RestheadSpec path %s not found", restheadPath)
}
if _, exists := spec.Paths[resolveSpecPath]; !exists {
t.Errorf("ResolveSpec path %s not found", resolveSpecPath)
}
}
func TestNilRegistry(t *testing.T) {
config := GeneratorConfig{
Title: "Test API",
Version: "1.0.0",
}
gen := NewGenerator(config)
_, err := gen.Generate()
if err == nil {
t.Error("Expected error for nil registry, got nil")
}
if !strings.Contains(err.Error(), "registry") {
t.Errorf("Error message should mention registry, got: %v", err)
}
}
func TestSecuritySchemes(t *testing.T) {
registry := modelregistry.NewModelRegistry()
config := GeneratorConfig{
Registry: registry,
}
gen := NewGenerator(config)
spec, err := gen.Generate()
if err != nil {
t.Fatalf("Generate failed: %v", err)
}
// Test that all security schemes are present
expectedSchemes := []string{"BearerAuth", "SessionToken", "CookieAuth", "HeaderAuth"}
for _, scheme := range expectedSchemes {
if _, exists := spec.Components.SecuritySchemes[scheme]; !exists {
t.Errorf("Security scheme %s not found", scheme)
}
}
// Test BearerAuth scheme details
bearerAuth := spec.Components.SecuritySchemes["BearerAuth"]
if bearerAuth.Type != "http" {
t.Errorf("BearerAuth type = %v, want http", bearerAuth.Type)
}
if bearerAuth.Scheme != "bearer" {
t.Errorf("BearerAuth scheme = %v, want bearer", bearerAuth.Scheme)
}
if bearerAuth.BearerFormat != "JWT" {
t.Errorf("BearerAuth format = %v, want JWT", bearerAuth.BearerFormat)
}
// Test HeaderAuth scheme details
headerAuth := spec.Components.SecuritySchemes["HeaderAuth"]
if headerAuth.Type != "apiKey" {
t.Errorf("HeaderAuth type = %v, want apiKey", headerAuth.Type)
}
if headerAuth.In != "header" {
t.Errorf("HeaderAuth in = %v, want header", headerAuth.In)
}
if headerAuth.Name != "X-User-ID" {
t.Errorf("HeaderAuth name = %v, want X-User-ID", headerAuth.Name)
}
}

499
pkg/openapi/paths.go Normal file
View File

@@ -0,0 +1,499 @@
package openapi
import (
"fmt"
)
// generateRestheadSpecPaths generates OpenAPI paths for RestheadSpec endpoints
func (g *Generator) generateRestheadSpecPaths(spec *OpenAPISpec, schema, entity, schemaName string) {
basePath := fmt.Sprintf("/%s/%s", schema, entity)
idPath := fmt.Sprintf("/%s/%s/{id}", schema, entity)
metaPath := fmt.Sprintf("/%s/%s/metadata", schema, entity)
// Collection endpoint: GET (list), POST (create)
spec.Paths[basePath] = PathItem{
Get: &Operation{
Summary: fmt.Sprintf("List %s records", entity),
Description: fmt.Sprintf("Retrieve a list of %s records with optional filtering, sorting, and pagination via headers", entity),
OperationID: fmt.Sprintf("listRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Parameters: g.getRestheadSpecHeaders(),
Responses: map[string]Response{
"200": {
Description: "Successful response",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Type: "array", Items: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)}},
"metadata": {Ref: "#/components/schemas/Metadata"},
},
},
},
},
},
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Post: &Operation{
Summary: fmt.Sprintf("Create %s record", entity),
Description: fmt.Sprintf("Create a new %s record", entity),
OperationID: fmt.Sprintf("createRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
RequestBody: &RequestBody{
Required: true,
Description: fmt.Sprintf("%s object to create", entity),
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
Responses: map[string]Response{
"201": {
Description: "Record created successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
},
},
"400": g.errorResponse("Bad request"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Options: &Operation{
Summary: "CORS preflight",
Description: "Handle CORS preflight requests",
OperationID: fmt.Sprintf("optionsRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Responses: map[string]Response{
"204": {Description: "No content"},
},
},
}
// Single record endpoint: GET (read), PUT/PATCH (update), DELETE
spec.Paths[idPath] = PathItem{
Get: &Operation{
Summary: fmt.Sprintf("Get %s record by ID", entity),
Description: fmt.Sprintf("Retrieve a single %s record by its ID", entity),
OperationID: fmt.Sprintf("getRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Parameters: []Parameter{
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
},
Responses: map[string]Response{
"200": {
Description: "Successful response",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
},
},
"404": g.errorResponse("Record not found"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Put: &Operation{
Summary: fmt.Sprintf("Update %s record", entity),
Description: fmt.Sprintf("Update an existing %s record by ID", entity),
OperationID: fmt.Sprintf("updateRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Parameters: []Parameter{
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
},
RequestBody: &RequestBody{
Required: true,
Description: fmt.Sprintf("Updated %s object", entity),
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
Responses: map[string]Response{
"200": {
Description: "Record updated successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
},
},
"400": g.errorResponse("Bad request"),
"404": g.errorResponse("Record not found"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Patch: &Operation{
Summary: fmt.Sprintf("Partially update %s record", entity),
Description: fmt.Sprintf("Partially update an existing %s record by ID", entity),
OperationID: fmt.Sprintf("patchRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Parameters: []Parameter{
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
},
RequestBody: &RequestBody{
Required: true,
Description: fmt.Sprintf("Partial %s object", entity),
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
Responses: map[string]Response{
"200": {
Description: "Record updated successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
},
},
"400": g.errorResponse("Bad request"),
"404": g.errorResponse("Record not found"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Delete: &Operation{
Summary: fmt.Sprintf("Delete %s record", entity),
Description: fmt.Sprintf("Delete a %s record by ID", entity),
OperationID: fmt.Sprintf("deleteRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Parameters: []Parameter{
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
},
Responses: map[string]Response{
"200": {
Description: "Record deleted successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
},
},
},
},
},
"404": g.errorResponse("Record not found"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
}
// Metadata endpoint
spec.Paths[metaPath] = PathItem{
Get: &Operation{
Summary: fmt.Sprintf("Get %s metadata", entity),
Description: fmt.Sprintf("Retrieve metadata information for %s table", entity),
OperationID: fmt.Sprintf("metadataRestheadSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (RestheadSpec)", entity)},
Responses: map[string]Response{
"200": {
Description: "Metadata retrieved successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {
Type: "object",
Properties: map[string]*Schema{
"schema": {Type: "string"},
"table": {Type: "string"},
"columns": {Type: "array", Items: &Schema{Type: "object"}},
},
},
},
},
},
},
},
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
}
}
// generateResolveSpecPaths generates OpenAPI paths for ResolveSpec endpoints
func (g *Generator) generateResolveSpecPaths(spec *OpenAPISpec, schema, entity, schemaName string) {
basePath := fmt.Sprintf("/resolve/%s/%s", schema, entity)
idPath := fmt.Sprintf("/resolve/%s/%s/{id}", schema, entity)
// Collection endpoint: POST (operations)
spec.Paths[basePath] = PathItem{
Post: &Operation{
Summary: fmt.Sprintf("Perform operation on %s", entity),
Description: fmt.Sprintf("Execute read, create, or meta operations on %s records", entity),
OperationID: fmt.Sprintf("operateResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
RequestBody: &RequestBody{
Required: true,
Description: "Operation request with operation type and options",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: "#/components/schemas/ResolveSpecRequest"},
Example: map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"limit": 10,
"filters": []map[string]interface{}{
{"column": "status", "operator": "eq", "value": "active"},
},
},
},
},
},
},
Responses: map[string]Response{
"200": {
Description: "Operation completed successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Type: "array", Items: &Schema{Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)}},
"metadata": {Ref: "#/components/schemas/Metadata"},
},
},
},
},
},
"400": g.errorResponse("Bad request"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Get: &Operation{
Summary: fmt.Sprintf("Get %s metadata", entity),
Description: fmt.Sprintf("Retrieve metadata for %s", entity),
OperationID: fmt.Sprintf("metadataResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
Responses: map[string]Response{
"200": {
Description: "Metadata retrieved successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: "#/components/schemas/Response"},
},
},
},
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
Options: &Operation{
Summary: "CORS preflight",
Description: "Handle CORS preflight requests",
OperationID: fmt.Sprintf("optionsResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
Responses: map[string]Response{
"204": {Description: "No content"},
},
},
}
// Single record endpoint: POST (update/delete)
spec.Paths[idPath] = PathItem{
Post: &Operation{
Summary: fmt.Sprintf("Update or delete %s record", entity),
Description: fmt.Sprintf("Execute update or delete operation on a specific %s record", entity),
OperationID: fmt.Sprintf("modifyResolveSpec%s%s", formatSchemaName(schema, ""), formatSchemaName("", entity)),
Tags: []string{fmt.Sprintf("%s (ResolveSpec)", entity)},
Parameters: []Parameter{
{Name: "id", In: "path", Required: true, Description: "Record ID", Schema: &Schema{Type: "integer"}},
},
RequestBody: &RequestBody{
Required: true,
Description: "Operation request (update or delete)",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: "#/components/schemas/ResolveSpecRequest"},
Example: map[string]interface{}{
"operation": "update",
"data": map[string]interface{}{
"status": "inactive",
},
},
},
},
},
Responses: map[string]Response{
"200": {
Description: "Operation completed successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{
Type: "object",
Properties: map[string]*Schema{
"success": {Type: "boolean"},
"data": {Ref: fmt.Sprintf("#/components/schemas/%s", schemaName)},
},
},
},
},
},
"400": g.errorResponse("Bad request"),
"404": g.errorResponse("Record not found"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
},
}
}
// generateFuncSpecPaths generates OpenAPI paths for FuncSpec endpoints
func (g *Generator) generateFuncSpecPaths(spec *OpenAPISpec) {
for path, endpoint := range g.config.FuncSpecEndpoints {
operation := &Operation{
Summary: endpoint.Summary,
Description: endpoint.Description,
OperationID: fmt.Sprintf("funcSpec%s", sanitizeOperationID(path)),
Tags: []string{"FuncSpec"},
Parameters: g.extractFuncSpecParameters(endpoint.Parameters),
Responses: map[string]Response{
"200": {
Description: "Query executed successfully",
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: "#/components/schemas/Response"},
},
},
},
"400": g.errorResponse("Bad request"),
"401": g.errorResponse("Unauthorized"),
"500": g.errorResponse("Internal server error"),
},
Security: g.securityRequirements(),
}
pathItem := spec.Paths[path]
switch endpoint.Method {
case "GET":
pathItem.Get = operation
case "POST":
pathItem.Post = operation
case "PUT":
pathItem.Put = operation
case "DELETE":
pathItem.Delete = operation
}
spec.Paths[path] = pathItem
}
}
// getRestheadSpecHeaders returns all RestheadSpec header parameters
func (g *Generator) getRestheadSpecHeaders() []Parameter {
return []Parameter{
{Name: "X-Filters", In: "header", Description: "JSON array of filter conditions", Schema: &Schema{Type: "string"}},
{Name: "X-Columns", In: "header", Description: "Comma-separated list of columns to select", Schema: &Schema{Type: "string"}},
{Name: "X-Sort", In: "header", Description: "JSON array of sort specifications", Schema: &Schema{Type: "string"}},
{Name: "X-Limit", In: "header", Description: "Maximum number of records to return", Schema: &Schema{Type: "integer"}},
{Name: "X-Offset", In: "header", Description: "Number of records to skip", Schema: &Schema{Type: "integer"}},
{Name: "X-Preload", In: "header", Description: "Relations to eager load (comma-separated)", Schema: &Schema{Type: "string"}},
{Name: "X-Expand", In: "header", Description: "Relations to expand with LEFT JOIN (comma-separated)", Schema: &Schema{Type: "string"}},
{Name: "X-Distinct", In: "header", Description: "Enable DISTINCT query (true/false)", Schema: &Schema{Type: "boolean"}},
{Name: "X-Response-Format", In: "header", Description: "Response format", Schema: &Schema{Type: "string", Enum: []interface{}{"detail", "simple", "syncfusion"}}},
{Name: "X-Clean-JSON", In: "header", Description: "Remove null/empty fields from response (true/false)", Schema: &Schema{Type: "boolean"}},
{Name: "X-Custom-SQL-Where", In: "header", Description: "Custom SQL WHERE clause (AND)", Schema: &Schema{Type: "string"}},
{Name: "X-Custom-SQL-Or", In: "header", Description: "Custom SQL WHERE clause (OR)", Schema: &Schema{Type: "string"}},
}
}
// extractFuncSpecParameters creates OpenAPI parameters from parameter names
func (g *Generator) extractFuncSpecParameters(paramNames []string) []Parameter {
params := []Parameter{}
for _, name := range paramNames {
params = append(params, Parameter{
Name: name,
In: "query",
Description: fmt.Sprintf("Parameter: %s", name),
Schema: &Schema{Type: "string"},
})
}
return params
}
// errorResponse creates a standard error response
func (g *Generator) errorResponse(description string) Response {
return Response{
Description: description,
Content: map[string]MediaType{
"application/json": {
Schema: &Schema{Ref: "#/components/schemas/APIError"},
},
},
}
}
// securityRequirements returns all security options (user can use any)
func (g *Generator) securityRequirements() []map[string][]string {
return []map[string][]string{
{"BearerAuth": {}},
{"SessionToken": {}},
{"CookieAuth": {}},
{"HeaderAuth": {}},
}
}
// sanitizeOperationID removes invalid characters from operation IDs
func sanitizeOperationID(path string) string {
result := ""
for _, char := range path {
if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9') {
result += string(char)
}
}
return result
}

View File

@@ -26,8 +26,7 @@ func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
} }
}() }()
var lst []ModelFieldDetail lst := make([]ModelFieldDetail, 0)
lst = make([]ModelFieldDetail, 0)
if !record.IsValid() { if !record.IsValid() {
return lst return lst

View File

@@ -0,0 +1,331 @@
package reflection
import (
"reflect"
"testing"
)
// Test models for GetModelColumnDetail
type TestModelForColumnDetail struct {
ID int `gorm:"column:rid_test;primaryKey;type:bigserial;not null" json:"id"`
Name string `gorm:"column:name;type:varchar(255);not null" json:"name"`
Email string `gorm:"column:email;type:varchar(255);unique;nullable" json:"email"`
Description string `gorm:"column:description;type:text;null" json:"description"`
ForeignKey int `gorm:"foreignKey:parent_id" json:"foreign_key"`
}
type EmbeddedBase struct {
ID int `gorm:"column:rid_base;primaryKey;identity" json:"id"`
CreatedAt string `gorm:"column:created_at;type:timestamp" json:"created_at"`
}
type ModelWithEmbeddedForDetail struct {
EmbeddedBase
Title string `gorm:"column:title;type:varchar(100);not null" json:"title"`
Content string `gorm:"column:content;type:text" json:"content"`
}
// Model with nil embedded pointer
type ModelWithNilEmbedded struct {
ID int `gorm:"column:id;primaryKey" json:"id"`
*EmbeddedBase
Name string `gorm:"column:name" json:"name"`
}
func TestGetModelColumnDetail(t *testing.T) {
t.Run("simple struct", func(t *testing.T) {
model := TestModelForColumnDetail{
ID: 1,
Name: "Test",
Email: "test@example.com",
Description: "Test description",
ForeignKey: 100,
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
// Check ID field
found := false
for _, detail := range details {
if detail.Name == "ID" {
found = true
if detail.SQLName != "rid_test" {
t.Errorf("Expected SQLName 'rid_test', got '%s'", detail.SQLName)
}
// Note: primaryKey (without underscore) is not detected as primary_key
// The function looks for "identity" or "primary_key" (with underscore)
if detail.SQLDataType != "bigserial" {
t.Errorf("Expected SQLDataType 'bigserial', got '%s'", detail.SQLDataType)
}
if detail.Nullable {
t.Errorf("Expected Nullable false, got true")
}
}
}
if !found {
t.Errorf("ID field not found in details")
}
})
t.Run("struct with embedded fields", func(t *testing.T) {
model := ModelWithEmbeddedForDetail{
EmbeddedBase: EmbeddedBase{
ID: 1,
CreatedAt: "2024-01-01",
},
Title: "Test Title",
Content: "Test Content",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 4 fields: ID, CreatedAt from embedded, Title, Content from main
if len(details) != 4 {
t.Errorf("Expected 4 fields, got %d", len(details))
}
// Check that embedded field is included
foundID := false
foundCreatedAt := false
for _, detail := range details {
if detail.Name == "ID" {
foundID = true
if detail.SQLKey != "primary_key" {
t.Errorf("Expected SQLKey 'primary_key' for embedded ID, got '%s'", detail.SQLKey)
}
}
if detail.Name == "CreatedAt" {
foundCreatedAt = true
}
}
if !foundID {
t.Errorf("Embedded ID field not found")
}
if !foundCreatedAt {
t.Errorf("Embedded CreatedAt field not found")
}
})
t.Run("nil embedded pointer is skipped", func(t *testing.T) {
model := ModelWithNilEmbedded{
ID: 1,
Name: "Test",
EmbeddedBase: nil, // nil embedded pointer
}
details := GetModelColumnDetail(reflect.ValueOf(model))
// Should have 2 fields: ID and Name (embedded is nil, so skipped)
if len(details) != 2 {
t.Errorf("Expected 2 fields (nil embedded skipped), got %d", len(details))
}
})
t.Run("pointer to struct", func(t *testing.T) {
model := &TestModelForColumnDetail{
ID: 1,
Name: "Test",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) != 5 {
t.Errorf("Expected 5 fields, got %d", len(details))
}
})
t.Run("invalid value", func(t *testing.T) {
var invalid reflect.Value
details := GetModelColumnDetail(invalid)
if len(details) != 0 {
t.Errorf("Expected 0 fields for invalid value, got %d", len(details))
}
})
t.Run("non-struct type", func(t *testing.T) {
details := GetModelColumnDetail(reflect.ValueOf(123))
if len(details) != 0 {
t.Errorf("Expected 0 fields for non-struct, got %d", len(details))
}
})
t.Run("nullable and not null detection", func(t *testing.T) {
model := TestModelForColumnDetail{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.Nullable {
t.Errorf("ID should not be nullable (has 'not null')")
}
case "Name":
if detail.Nullable {
t.Errorf("Name should not be nullable (has 'not null')")
}
case "Email":
if !detail.Nullable {
t.Errorf("Email should be nullable (has 'nullable')")
}
case "Description":
if !detail.Nullable {
t.Errorf("Description should be nullable (has 'null')")
}
}
}
})
t.Run("unique and uniqueindex detection", func(t *testing.T) {
type UniqueTestModel struct {
ID int `gorm:"column:id;primary_key"`
Username string `gorm:"column:username;unique"`
Email string `gorm:"column:email;uniqueindex"`
}
model := UniqueTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
switch detail.Name {
case "ID":
if detail.SQLKey != "primary_key" {
t.Errorf("ID should have SQLKey 'primary_key', got '%s'", detail.SQLKey)
}
case "Username":
if detail.SQLKey != "unique" {
t.Errorf("Username should have SQLKey 'unique', got '%s'", detail.SQLKey)
}
case "Email":
// The function checks for "unique" first, so uniqueindex is also detected as "unique"
// This is expected behavior based on the code logic
if detail.SQLKey != "unique" {
t.Errorf("Email should have SQLKey 'unique' (uniqueindex contains 'unique'), got '%s'", detail.SQLKey)
}
}
}
})
t.Run("foreign key detection", func(t *testing.T) {
// Note: The foreignkey extraction in generic_model.go has a bug where
// it requires ik > 0, so foreignkey at the start won't extract the value
type FKTestModel struct {
ParentID int `gorm:"column:parent_id;foreignkey:rid_parent;association_foreignkey:id_atevent"`
}
model := FKTestModel{}
details := GetModelColumnDetail(reflect.ValueOf(model))
if len(details) == 0 {
t.Fatal("Expected at least 1 field")
}
detail := details[0]
if detail.SQLKey != "foreign_key" {
t.Errorf("Expected SQLKey 'foreign_key', got '%s'", detail.SQLKey)
}
// Due to the bug in the code (requires ik > 0), the SQLName will be extracted
// when foreignkey is not at the beginning of the string
if detail.SQLName != "rid_parent" {
t.Errorf("Expected SQLName 'rid_parent', got '%s'", detail.SQLName)
}
})
}
func TestFnFindKeyVal(t *testing.T) {
tests := []struct {
name string
src string
key string
expected string
}{
{
name: "find column",
src: "column:user_id;primaryKey;type:bigint",
key: "column:",
expected: "user_id",
},
{
name: "find type",
src: "column:name;type:varchar(255);not null",
key: "type:",
expected: "varchar(255)",
},
{
name: "key not found",
src: "primaryKey;autoIncrement",
key: "column:",
expected: "",
},
{
name: "key at end without semicolon",
src: "primaryKey;column:id",
key: "column:",
expected: "id",
},
{
name: "case insensitive search",
src: "Column:user_id;primaryKey",
key: "column:",
expected: "user_id",
},
{
name: "empty src",
src: "",
key: "column:",
expected: "",
},
{
name: "multiple occurrences (returns first)",
src: "column:first;column:second",
key: "column:",
expected: "first",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := fnFindKeyVal(tt.src, tt.key)
if result != tt.expected {
t.Errorf("fnFindKeyVal(%q, %q) = %q, want %q", tt.src, tt.key, result, tt.expected)
}
})
}
}
func TestGetModelColumnDetail_FieldValue(t *testing.T) {
model := TestModelForColumnDetail{
ID: 123,
Name: "TestName",
Email: "test@example.com",
}
details := GetModelColumnDetail(reflect.ValueOf(model))
for _, detail := range details {
if !detail.FieldValue.IsValid() {
t.Errorf("Field %s has invalid FieldValue", detail.Name)
}
// Check that FieldValue matches the actual value
switch detail.Name {
case "ID":
if detail.FieldValue.Int() != 123 {
t.Errorf("Expected ID FieldValue 123, got %v", detail.FieldValue.Int())
}
case "Name":
if detail.FieldValue.String() != "TestName" {
t.Errorf("Expected Name FieldValue 'TestName', got %v", detail.FieldValue.String())
}
case "Email":
if detail.FieldValue.String() != "test@example.com" {
t.Errorf("Expected Email FieldValue 'test@example.com', got %v", detail.FieldValue.String())
}
}
}
}

View File

@@ -17,3 +17,33 @@ func Len(v any) int {
return 0 return 0
} }
} }
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
// dots, it returns everything after the last dot up to the first delimiter.
func ExtractTableNameOnly(fullName string) string {
// First, split by dot to remove schema prefix if present
lastDotIndex := -1
for i, char := range fullName {
if char == '.' {
lastDotIndex = i
}
}
// Start from after the last dot (or from beginning if no dot)
startIndex := 0
if lastDotIndex != -1 {
startIndex = lastDotIndex + 1
}
// Now find the end (first delimiter after the table name)
for i := startIndex; i < len(fullName); i++ {
char := rune(fullName[i])
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
return fullName[startIndex:i]
}
}
return fullName[startIndex:]
}

View File

@@ -1,7 +1,9 @@
package reflection package reflection
import ( import (
"fmt"
"reflect" "reflect"
"strconv"
"strings" "strings"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -132,7 +134,7 @@ func findFieldByName(val reflect.Value, name string) any {
} }
// Check if field name matches // Check if field name matches
if strings.ToLower(field.Name) == name && fieldValue.CanInterface() { if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
return fieldValue.Interface() return fieldValue.Interface()
} }
} }
@@ -323,6 +325,127 @@ func ExtractColumnFromBunTag(tag string) string {
return "" return ""
} }
// GetSQLModelColumns extracts column names that have valid SQL field mappings
// This function only returns columns that:
// 1. Have bun or gorm tags (not just json tags)
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
// 3. Are not scan-only embedded fields
func GetSQLModelColumns(model any) []string {
var columns []string
modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return columns
}
collectSQLColumnsFromType(modelType, &columns, false)
return columns
}
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i)
// Check if this is an embedded struct
if field.Anonymous {
// Unwrap pointer type if necessary
fieldType := field.Type
if fieldType.Kind() == reflect.Pointer {
fieldType = fieldType.Elem()
}
// Check if the embedded struct itself is scan-only
isScanOnly := scanOnlyEmbedded
bunTag := field.Tag.Get("bun")
if bunTag != "" && isBunFieldScanOnly(bunTag) {
isScanOnly = true
}
// Recursively process embedded struct
if fieldType.Kind() == reflect.Struct {
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
continue
}
}
// Skip fields in scan-only embedded structs
if scanOnlyEmbedded {
continue
}
// Get bun and gorm tags
bunTag := field.Tag.Get("bun")
gormTag := field.Tag.Get("gorm")
// Skip if neither bun nor gorm tag exists
if bunTag == "" && gormTag == "" {
continue
}
// Skip if explicitly marked with "-"
if bunTag == "-" || gormTag == "-" {
continue
}
// Skip if field itself is scan-only (bun)
if bunTag != "" && isBunFieldScanOnly(bunTag) {
continue
}
// Skip if field itself is read-only (gorm)
if gormTag != "" && isGormFieldReadOnly(gormTag) {
continue
}
// Skip relation fields (bun)
if bunTag != "" {
// Skip if it's a bun relation (rel:, join:, or m2m:)
if strings.Contains(bunTag, "rel:") ||
strings.Contains(bunTag, "join:") ||
strings.Contains(bunTag, "m2m:") {
continue
}
}
// Skip relation fields (gorm)
if gormTag != "" {
// Skip if it has gorm relationship tags
if strings.Contains(gormTag, "foreignKey:") ||
strings.Contains(gormTag, "references:") ||
strings.Contains(gormTag, "many2many:") ||
strings.Contains(gormTag, "constraint:") {
continue
}
}
// Get column name
columnName := ""
if bunTag != "" {
columnName = ExtractColumnFromBunTag(bunTag)
}
if columnName == "" && gormTag != "" {
columnName = ExtractColumnFromGormTag(gormTag)
}
// Skip if we couldn't extract a column name
if columnName == "" {
continue
}
*columns = append(*columns, columnName)
}
}
// IsColumnWritable checks if a column can be written to in the database // IsColumnWritable checks if a column can be written to in the database
// For bun: returns false if the field has "scanonly" tag // For bun: returns false if the field has "scanonly" tag
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag // For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
@@ -351,7 +474,7 @@ func IsColumnWritable(model any, columnName string) bool {
// isColumnWritableInType recursively searches for a column and checks if it's writable // isColumnWritableInType recursively searches for a column and checks if it's writable
// Returns (found, writable) where found indicates if the column was found // Returns (found, writable) where found indicates if the column was found
func isColumnWritableInType(typ reflect.Type, columnName string) (bool, bool) { func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
for i := 0; i < typ.NumField(); i++ { for i := 0; i < typ.NumField(); i++ {
field := typ.Field(i) field := typ.Field(i)
@@ -440,3 +563,433 @@ func isGormFieldReadOnly(tag string) bool {
} }
return false return false
} }
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
// Examples:
// - "columna->>'val'" returns "columna"
// - "columna->'key'" returns "columna"
// - "columna" returns "columna"
// - "table.columna->>'val'" returns "table.columna"
func ExtractSourceColumn(colName string) string {
// Check for PostgreSQL JSON operators: -> and ->>
if idx := strings.Index(colName, "->>"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
if idx := strings.Index(colName, "->"); idx != -1 {
return strings.TrimSpace(colName[:idx])
}
return colName
}
// ToSnakeCase converts a string from CamelCase to snake_case
func ToSnakeCase(s string) string {
var result strings.Builder
for i, r := range s {
if i > 0 && r >= 'A' && r <= 'Z' {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
if model == nil {
return reflect.Invalid
}
// Extract the source column name (remove JSON operators like ->> or ->)
sourceColName := ExtractSourceColumn(colName)
modelType := reflect.TypeOf(model)
// Dereference pointer if needed
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
// Ensure it's a struct
if modelType.Kind() != reflect.Struct {
return reflect.Invalid
}
// Find the field by JSON tag or field name
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" {
// Parse JSON tag (format: "name,omitempty")
parts := strings.Split(jsonTag, ",")
if parts[0] == sourceColName {
return field.Type.Kind()
}
}
// Check field name (case-insensitive)
if strings.EqualFold(field.Name, sourceColName) {
return field.Type.Kind()
}
// Check snake_case conversion
snakeCaseName := ToSnakeCase(field.Name)
if snakeCaseName == sourceColName {
return field.Type.Kind()
}
}
return reflect.Invalid
}
// IsNumericType checks if a reflect.Kind is a numeric type
func IsNumericType(kind reflect.Kind) bool {
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
}
// IsStringType checks if a reflect.Kind is a string type
func IsStringType(kind reflect.Kind) bool {
return kind == reflect.String
}
// IsNumericValue checks if a string value can be parsed as a number
func IsNumericValue(value string) bool {
value = strings.TrimSpace(value)
_, err := strconv.ParseFloat(value, 64)
return err == nil
}
// ConvertToNumericType converts a string value to the appropriate numeric type
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
value = strings.TrimSpace(value)
switch kind {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
// Parse as integer
bitSize := 64
switch kind {
case reflect.Int8:
bitSize = 8
case reflect.Int16:
bitSize = 16
case reflect.Int32:
bitSize = 32
}
intVal, err := strconv.ParseInt(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Int:
return int(intVal), nil
case reflect.Int8:
return int8(intVal), nil
case reflect.Int16:
return int16(intVal), nil
case reflect.Int32:
return int32(intVal), nil
case reflect.Int64:
return intVal, nil
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
// Parse as unsigned integer
bitSize := 64
switch kind {
case reflect.Uint8:
bitSize = 8
case reflect.Uint16:
bitSize = 16
case reflect.Uint32:
bitSize = 32
}
uintVal, err := strconv.ParseUint(value, 10, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
}
// Return the appropriate type
switch kind {
case reflect.Uint:
return uint(uintVal), nil
case reflect.Uint8:
return uint8(uintVal), nil
case reflect.Uint16:
return uint16(uintVal), nil
case reflect.Uint32:
return uint32(uintVal), nil
case reflect.Uint64:
return uintVal, nil
}
case reflect.Float32, reflect.Float64:
// Parse as float
bitSize := 64
if kind == reflect.Float32 {
bitSize = 32
}
floatVal, err := strconv.ParseFloat(value, bitSize)
if err != nil {
return nil, fmt.Errorf("invalid float value: %w", err)
}
if kind == reflect.Float32 {
return float32(floatVal), nil
}
return floatVal, nil
}
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
}
// RelationType represents the type of database relationship
type RelationType string
const (
RelationHasMany RelationType = "has-many" // 1:N - use separate query
RelationBelongsTo RelationType = "belongs-to" // N:1 - use JOIN
RelationHasOne RelationType = "has-one" // 1:1 - use JOIN
RelationManyToMany RelationType = "many-to-many" // M:N - use separate query
RelationUnknown RelationType = "unknown"
)
// ShouldUseJoin returns true if the relation type should use a JOIN instead of separate query
func (rt RelationType) ShouldUseJoin() bool {
return rt == RelationBelongsTo || rt == RelationHasOne
}
// GetRelationType inspects the model's struct tags to determine the relationship type
// It checks both Bun and GORM tags to identify the relationship cardinality
func GetRelationType(model interface{}, fieldName string) RelationType {
if model == nil || fieldName == "" {
return RelationUnknown
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return RelationUnknown
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return RelationUnknown
}
// Find the field
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
// Check if field name matches (case-insensitive)
if !strings.EqualFold(field.Name, fieldName) {
continue
}
// Check Bun tags first
bunTag := field.Tag.Get("bun")
if bunTag != "" && strings.Contains(bunTag, "rel:") {
// Parse bun relation tag: rel:has-many, rel:belongs-to, rel:has-one, rel:many-to-many
parts := strings.Split(bunTag, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "rel:") {
relType := strings.TrimPrefix(part, "rel:")
switch relType {
case "has-many":
return RelationHasMany
case "belongs-to":
return RelationBelongsTo
case "has-one":
return RelationHasOne
case "many-to-many", "m2m":
return RelationManyToMany
}
}
}
}
// Check GORM tags
gormTag := field.Tag.Get("gorm")
if gormTag != "" {
// GORM uses different patterns:
// - foreignKey: usually indicates belongs-to or has-one
// - many2many: indicates many-to-many
// - Field type (slice vs pointer) helps determine cardinality
if strings.Contains(gormTag, "many2many:") {
return RelationManyToMany
}
// Check field type for cardinality hints
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice indicates has-many or many-to-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr {
// Pointer to single struct usually indicates belongs-to or has-one
// Check if it has foreignKey (belongs-to) or references (has-one)
if strings.Contains(gormTag, "foreignKey:") {
return RelationBelongsTo
}
return RelationHasOne
}
}
// Fall back to field type inference
fieldType := field.Type
if fieldType.Kind() == reflect.Slice {
// Slice of structs → has-many
return RelationHasMany
}
if fieldType.Kind() == reflect.Ptr || fieldType.Kind() == reflect.Struct {
// Single struct → belongs-to (default assumption for safety)
// Using belongs-to as default ensures we use JOIN, which is safer
return RelationBelongsTo
}
}
return RelationUnknown
}
// GetRelationModel gets the model type for a relation field
// It searches for the field by name in the following order (case-insensitive):
// 1. Actual field name
// 2. Bun tag name (if exists)
// 3. Gorm tag name (if exists)
// 4. JSON tag name (if exists)
//
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
// For nested fields, it traverses through each level of the struct hierarchy
func GetRelationModel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
// Split the field name by "." to handle nested/recursive relations
fieldParts := strings.Split(fieldName, ".")
// Start with the current model
currentModel := model
// Traverse through each level of the field path
for _, part := range fieldParts {
if part == "" {
continue
}
currentModel = getRelationModelSingleLevel(currentModel, part)
if currentModel == nil {
return nil
}
}
return currentModel
}
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
// This is a helper function used by GetRelationModel to handle one level at a time
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
if model == nil || fieldName == "" {
return nil
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nil
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil
}
// Find the field by checking in priority order (case-insensitive)
var field *reflect.StructField
normalizedFieldName := strings.ToLower(fieldName)
for i := 0; i < modelType.NumField(); i++ {
f := modelType.Field(i)
// 1. Check actual field name (case-insensitive)
if strings.EqualFold(f.Name, fieldName) {
field = &f
break
}
// 2. Check bun tag name
bunTag := f.Tag.Get("bun")
if bunTag != "" {
bunColName := ExtractColumnFromBunTag(bunTag)
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
field = &f
break
}
}
// 3. Check gorm tag name
gormTag := f.Tag.Get("gorm")
if gormTag != "" {
gormColName := ExtractColumnFromGormTag(gormTag)
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
field = &f
break
}
}
// 4. Check JSON tag name
jsonTag := f.Tag.Get("json")
if jsonTag != "" {
parts := strings.Split(jsonTag, ",")
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
if strings.EqualFold(parts[0], normalizedFieldName) {
field = &f
break
}
}
}
}
if field == nil {
return nil
}
// Get the target type
targetType := field.Type
if targetType == nil {
return nil
}
if targetType.Kind() == reflect.Slice {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
if targetType == nil {
return nil
}
}
if targetType.Kind() != reflect.Struct {
return nil
}
// Create a zero value of the target type
return reflect.New(targetType).Elem().Interface()
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,138 @@
package resolvespec
import (
"context"
"testing"
)
func TestContextOperations(t *testing.T) {
ctx := context.Background()
// Test Schema
t.Run("WithSchema and GetSchema", func(t *testing.T) {
ctx = WithSchema(ctx, "public")
schema := GetSchema(ctx)
if schema != "public" {
t.Errorf("Expected schema 'public', got '%s'", schema)
}
})
// Test Entity
t.Run("WithEntity and GetEntity", func(t *testing.T) {
ctx = WithEntity(ctx, "users")
entity := GetEntity(ctx)
if entity != "users" {
t.Errorf("Expected entity 'users', got '%s'", entity)
}
})
// Test TableName
t.Run("WithTableName and GetTableName", func(t *testing.T) {
ctx = WithTableName(ctx, "public.users")
tableName := GetTableName(ctx)
if tableName != "public.users" {
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
}
})
// Test Model
t.Run("WithModel and GetModel", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
ctx = WithModel(ctx, model)
retrieved := GetModel(ctx)
if retrieved == nil {
t.Error("Expected model to be retrieved, got nil")
}
if retrievedModel, ok := retrieved.(*TestModel); ok {
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
}
} else {
t.Error("Retrieved model is not of expected type")
}
})
// Test ModelPtr
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
type TestModel struct {
ID int
}
models := []*TestModel{}
ctx = WithModelPtr(ctx, &models)
retrieved := GetModelPtr(ctx)
if retrieved == nil {
t.Error("Expected modelPtr to be retrieved, got nil")
}
})
// Test WithRequestData
t.Run("WithRequestData", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
modelPtr := &[]*TestModel{}
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr)
if GetSchema(ctx) != "test_schema" {
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
}
if GetEntity(ctx) != "test_entity" {
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
}
if GetTableName(ctx) != "test_schema.test_entity" {
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
}
if GetModel(ctx) == nil {
t.Error("Expected model to be set")
}
if GetModelPtr(ctx) == nil {
t.Error("Expected modelPtr to be set")
}
})
}
func TestEmptyContext(t *testing.T) {
ctx := context.Background()
t.Run("GetSchema with empty context", func(t *testing.T) {
schema := GetSchema(ctx)
if schema != "" {
t.Errorf("Expected empty schema, got '%s'", schema)
}
})
t.Run("GetEntity with empty context", func(t *testing.T) {
entity := GetEntity(ctx)
if entity != "" {
t.Errorf("Expected empty entity, got '%s'", entity)
}
})
t.Run("GetTableName with empty context", func(t *testing.T) {
tableName := GetTableName(ctx)
if tableName != "" {
t.Errorf("Expected empty tableName, got '%s'", tableName)
}
})
t.Run("GetModel with empty context", func(t *testing.T) {
model := GetModel(ctx)
if model != nil {
t.Errorf("Expected nil model, got %v", model)
}
})
t.Run("GetModelPtr with empty context", func(t *testing.T) {
modelPtr := GetModelPtr(ctx)
if modelPtr != nil {
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
}
})
}

179
pkg/resolvespec/cursor.go Normal file
View File

@@ -0,0 +1,179 @@
package resolvespec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// CursorDirection defines pagination direction
type CursorDirection int
const (
CursorForward CursorDirection = 1
CursorBackward CursorDirection = -1
)
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
// It uses the current request's sort and cursor values.
//
// Parameters:
// - tableName: name of the main table (e.g. "posts")
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information
//
// Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
) (string, error) {
// Remove schema prefix if present
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
// --------------------------------------------------------------------- //
// 1. Determine active cursor
// --------------------------------------------------------------------- //
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
// --------------------------------------------------------------------- //
// 2. Extract sort columns
// --------------------------------------------------------------------- //
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
// --------------------------------------------------------------------- //
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
if col == "" {
continue
}
// Parse: "created_at", "user.name", etc.
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
// Direction from struct
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
// Resolve column
cursorCol, targetCol, err := resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
// Build inequality
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
// --------------------------------------------------------------------- //
// 5. Build priority OR-AND chain
// --------------------------------------------------------------------- //
orSQL := buildPriorityChain(whereClauses)
// --------------------------------------------------------------------- //
// 6. Final EXISTS subquery
// --------------------------------------------------------------------- //
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
pkName,
cursorID,
orSQL,
)
return query, nil
}
// ------------------------------------------------------------------------- //
// Helper: get active cursor (forward or backward)
func getActiveCursor(options common.RequestOptions) (id string, direction CursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, CursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, CursorBackward
}
return "", 0
}
// Helper: resolve column (main table only for now)
func resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil
}
// Joined column (not supported in resolvespec yet)
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
}
return "", "", fmt.Errorf("invalid column: %s", field)
}
// ------------------------------------------------------------------------- //
// Helper: build OR-AND priority chain
func buildPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

View File

@@ -0,0 +1,378 @@
package resolvespec
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestGetCursorFilter_Forward(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
CursorForward: "123",
}
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
if filter == "" {
t.Fatal("Expected non-empty cursor filter")
}
// Verify filter contains EXISTS subquery
if !strings.Contains(filter, "EXISTS") {
t.Errorf("Filter should contain EXISTS subquery, got: %s", filter)
}
// Verify filter references the cursor ID
if !strings.Contains(filter, "123") {
t.Errorf("Filter should reference cursor ID 123, got: %s", filter)
}
// Verify filter contains the table name
if !strings.Contains(filter, tableName) {
t.Errorf("Filter should reference table name %s, got: %s", tableName, filter)
}
// Verify filter contains primary key
if !strings.Contains(filter, pkName) {
t.Errorf("Filter should reference primary key %s, got: %s", pkName, filter)
}
t.Logf("Generated cursor filter: %s", filter)
}
func TestGetCursorFilter_Backward(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
CursorBackward: "456",
}
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
if filter == "" {
t.Fatal("Expected non-empty cursor filter")
}
// Verify filter contains cursor ID
if !strings.Contains(filter, "456") {
t.Errorf("Filter should reference cursor ID 456, got: %s", filter)
}
// For backward cursor, sort direction should be reversed
// This is handled internally by the GetCursorFilter function
t.Logf("Generated backward cursor filter: %s", filter)
}
func TestGetCursorFilter_NoCursor(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
},
// No cursor set
}
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err == nil {
t.Error("Expected error when no cursor is provided")
}
if !strings.Contains(err.Error(), "no cursor provided") {
t.Errorf("Expected 'no cursor provided' error, got: %v", err)
}
}
func TestGetCursorFilter_NoSort(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{},
CursorForward: "123",
}
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err == nil {
t.Error("Expected error when no sort columns are defined")
}
if !strings.Contains(err.Error(), "no sort columns") {
t.Errorf("Expected 'no sort columns' error, got: %v", err)
}
}
func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "priority", Direction: "DESC"},
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
CursorForward: "789",
}
tableName := "tasks"
pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Verify filter contains priority column
if !strings.Contains(filter, "priority") {
t.Errorf("Filter should reference priority column, got: %s", filter)
}
// Verify filter contains created_at column
if !strings.Contains(filter, "created_at") {
t.Errorf("Filter should reference created_at column, got: %s", filter)
}
t.Logf("Generated multi-column cursor filter: %s", filter)
}
func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "name", Direction: "ASC"},
},
CursorForward: "100",
}
tableName := "public.users"
pkName := "id"
modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
}
func TestGetActiveCursor(t *testing.T) {
tests := []struct {
name string
options common.RequestOptions
expectedID string
expectedDirection CursorDirection
}{
{
name: "Forward cursor only",
options: common.RequestOptions{
CursorForward: "123",
},
expectedID: "123",
expectedDirection: CursorForward,
},
{
name: "Backward cursor only",
options: common.RequestOptions{
CursorBackward: "456",
},
expectedID: "456",
expectedDirection: CursorBackward,
},
{
name: "Both cursors - forward takes precedence",
options: common.RequestOptions{
CursorForward: "123",
CursorBackward: "456",
},
expectedID: "123",
expectedDirection: CursorForward,
},
{
name: "No cursors",
options: common.RequestOptions{},
expectedID: "",
expectedDirection: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
id, direction := getActiveCursor(tt.options)
if id != tt.expectedID {
t.Errorf("Expected cursor ID %q, got %q", tt.expectedID, id)
}
if direction != tt.expectedDirection {
t.Errorf("Expected direction %d, got %d", tt.expectedDirection, direction)
}
})
}
}
func TestResolveColumn(t *testing.T) {
tests := []struct {
name string
field string
prefix string
tableName string
modelColumns []string
wantCursor string
wantTarget string
wantErr bool
}{
{
name: "Simple column",
field: "id",
prefix: "",
tableName: "users",
modelColumns: []string{"id", "name", "email"},
wantCursor: "cursor_select.id",
wantTarget: "users.id",
wantErr: false,
},
{
name: "Column with case insensitive match",
field: "NAME",
prefix: "",
tableName: "users",
modelColumns: []string{"id", "name", "email"},
wantCursor: "cursor_select.NAME",
wantTarget: "users.NAME",
wantErr: false,
},
{
name: "Invalid column",
field: "invalid_field",
prefix: "",
tableName: "users",
modelColumns: []string{"id", "name", "email"},
wantErr: true,
},
{
name: "JSON field",
field: "metadata->>'key'",
prefix: "",
tableName: "posts",
modelColumns: []string{"id", "metadata"},
wantCursor: "cursor_select.metadata->>'key'",
wantTarget: "posts.metadata->>'key'",
wantErr: false,
},
{
name: "Joined column (not supported)",
field: "name",
prefix: "user",
tableName: "posts",
modelColumns: []string{"id", "title"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr {
if err == nil {
t.Error("Expected error but got none")
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
}
if target != tt.wantTarget {
t.Errorf("Expected target %q, got %q", tt.wantTarget, target)
}
})
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > tasks.priority",
"cursor_select.created_at > tasks.created_at",
"cursor_select.id < tasks.id",
}
result := buildPriorityChain(clauses)
// Should build OR-AND chain for cursor comparison
if !strings.Contains(result, "OR") {
t.Error("Priority chain should contain OR operators")
}
if !strings.Contains(result, "AND") {
t.Error("Priority chain should contain AND operators for composite conditions")
}
// First clause should appear standalone
if !strings.Contains(result, clauses[0]) {
t.Errorf("Priority chain should contain first clause: %s", clauses[0])
}
t.Logf("Built priority chain: %s", result)
}
func TestCursorFilter_SQL_Safety(t *testing.T) {
// Test that cursor filter doesn't allow SQL injection
options := common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
},
CursorForward: "123; DROP TABLE users; --",
}
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// The cursor ID is inserted directly into the query
// This should be sanitized by the sanitizeWhereClause function in the handler
// For now, just verify it generates a filter
if filter == "" {
t.Error("Expected non-empty cursor filter even with special characters")
}
t.Logf("Generated filter with special chars in cursor: %s", filter)
}

View File

@@ -8,17 +8,26 @@ import (
"reflect" "reflect"
"runtime/debug" "runtime/debug"
"strings" "strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// FallbackHandler is a function that handles requests when no model is found
// It receives the same parameters as the Handle method
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
// Handler handles API requests using database and model abstractions // Handler handles API requests using database and model abstractions
type Handler struct { type Handler struct {
db common.Database db common.Database
registry common.ModelRegistry registry common.ModelRegistry
nestedProcessor *common.NestedCUDProcessor nestedProcessor *common.NestedCUDProcessor
hooks *HookRegistry
fallbackHandler FallbackHandler
openAPIGenerator func() (string, error)
} }
// NewHandler creates a new API handler with database and registry abstractions // NewHandler creates a new API handler with database and registry abstractions
@@ -26,12 +35,31 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
handler := &Handler{ handler := &Handler{
db: db, db: db,
registry: registry, registry: registry,
hooks: NewHookRegistry(),
} }
// Initialize nested processor // Initialize nested processor
handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler) handler.nestedProcessor = common.NewNestedCUDProcessor(db, registry, handler)
return handler return handler
} }
// Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// SetFallbackHandler sets a fallback handler to be called when no model is found
// If not set, the handler will simply return (pass through to next route)
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
h.fallbackHandler = fallback
}
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// handlePanic is a helper function to handle panics with stack traces // handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) { func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack() stack := debug.Stack()
@@ -48,7 +76,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
} }
}() }()
ctx := context.Background() // Check for ?openapi query parameter
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
h.HandleOpenAPI(w, r)
return
}
ctx := r.UnderlyingRequest().Context()
body, err := r.Body() body, err := r.Body()
if err != nil { if err != nil {
@@ -73,33 +107,27 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Get model and populate context with request-scoped data // Get model and populate context with request-scoped data
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
logger.Error("Invalid entity: %v", err) // Model not found - call fallback handler if set, otherwise pass through
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return return
} }
// Validate that the model is a struct type (not a slice or pointer to slice) // Validate and unwrap model using common utility
modelType := reflect.TypeOf(model) result, err := common.ValidateAndUnwrapModel(model)
originalType := modelType if err != nil {
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { logger.Error("Model for %s.%s validation failed: %v", schema, entity, err)
modelType = modelType.Elem() h.sendError(w, http.StatusInternalServerError, "invalid_model_type", err.Error(), err)
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return return
} }
// If the registered model was a pointer or slice, use the unwrapped struct type model = result.Model
if originalType != modelType { modelPtr := result.ModelPtr
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
// Add request-scoped data to context // Add request-scoped data to context
@@ -118,6 +146,8 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options) h.handleUpdate(ctx, w, id, req.ID, req.Data, req.Options)
case "delete": case "delete":
h.handleDelete(ctx, w, id, req.Data) h.handleDelete(ctx, w, id, req.Data)
case "meta":
h.handleMeta(ctx, w, schema, entity, model)
default: default:
logger.Error("Invalid operation: %s", req.Operation) logger.Error("Invalid operation: %s", req.Operation)
h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil) h.sendError(w, http.StatusBadRequest, "invalid_operation", "Invalid operation", nil)
@@ -133,6 +163,12 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
} }
}() }()
// Check for ?openapi query parameter
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
h.HandleOpenAPI(w, r)
return
}
schema := params["schema"] schema := params["schema"]
entity := params["entity"] entity := params["entity"]
@@ -140,8 +176,14 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
logger.Error("Failed to get model: %v", err) // Model not found - call fallback handler if set, otherwise pass through
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return return
} }
@@ -149,6 +191,21 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
h.sendResponse(w, metadata, nil) h.sendResponse(w, metadata, nil)
} }
// handleMeta processes meta operation requests
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleMeta", err)
}
}()
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
metadata := h.generateMetadata(schema, entity, model)
h.sendResponse(w, metadata, nil)
}
func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) { func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id string, options common.RequestOptions) {
// Capture panics and return error response // Capture panics and return error response
defer func() { defer func() {
@@ -191,10 +248,17 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName) query = query.Table(tableName)
} }
if len(options.Columns) == 0 && (len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetSQLModelColumns(model)
}
// Apply column selection // Apply column selection
if len(options.Columns) > 0 { if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns) logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...) for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
} }
if len(options.ComputedColumns) > 0 { if len(options.ComputedColumns) > 0 {
@@ -206,7 +270,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply preloading // Apply preloading
if len(options.Preload) > 0 { if len(options.Preload) > 0 {
query = h.applyPreloads(model, query, options.Preload) var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
logger.Error("Failed to apply preloads: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_preload", "Failed to apply preloads", err)
return
}
} }
// Apply filters // Apply filters
@@ -225,14 +295,91 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction)) query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
} }
// Get total count before pagination // Apply cursor-based pagination
total, err := query.Count(ctx) if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
if err != nil { logger.Debug("Applying cursor pagination")
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err) // Get primary key name
return pkName := reflection.GetPrimaryKeyName(model)
// Extract model columns for validation
modelColumns := reflection.GetModelColumns(model)
// Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
return
}
// Apply cursor filter to query
if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter)
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
}
}
// Get total count before pagination
var total int
// Try to get from cache first
// Use extended cache key if cursors are present
var cacheKeyHash string
if len(options.CursorForward) > 0 || len(options.CursorBackward) > 0 {
cacheKeyHash = cache.BuildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
nil, // No expand options in resolvespec
false, // distinct not used here
options.CursorForward,
options.CursorBackward,
)
} else {
cacheKeyHash = cache.BuildQueryCacheKey(
tableName,
options.Filters,
options.Sort,
"", // No custom SQL WHERE in resolvespec
"", // No custom SQL OR in resolvespec
)
}
cacheKey := cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
var cachedTotal cache.CachedTotal
err := cache.GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
// Cache miss - execute count query
logger.Debug("Cache miss for query total")
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
}
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
} }
logger.Debug("Total records before filtering: %d", total)
// Apply pagination // Apply pagination
if options.Limit != nil && *options.Limit > 0 { if options.Limit != nil && *options.Limit > 0 {
@@ -1105,7 +1252,7 @@ type relationshipInfo struct {
relatedModel interface{} relatedModel interface{}
} }
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) common.SelectQuery { func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
// Unwrap pointers, slices, and arrays to get to the base struct type // Unwrap pointers, slices, and arrays to get to the base struct type
@@ -1116,7 +1263,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// Validate that we have a struct type // Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct { if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Warn("Cannot apply preloads to non-struct type: %v", modelType) logger.Warn("Cannot apply preloads to non-struct type: %v", modelType)
return query return query, nil
} }
for idx := range preloads { for idx := range preloads {
@@ -1132,15 +1279,25 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
// ORMs like GORM and Bun expect the struct field name, not the JSON name // ORMs like GORM and Bun expect the struct field name, not the JSON name
relationFieldName := relInfo.fieldName relationFieldName := relInfo.fieldName
// For now, we'll preload without conditions // Validate and fix WHERE clause to ensure it contains the relation prefix
// TODO: Implement column selection and filtering for preloads if len(preload.Where) > 0 {
// This requires a more sophisticated approach with callbacks or query builders fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, relationFieldName)
// Apply preloading if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", relationFieldName, err)
return query, fmt.Errorf("invalid preload WHERE clause for relation '%s': %w", relationFieldName, err)
}
preload.Where = fixedWhere
}
logger.Debug("Applying preload: %s", relationFieldName) logger.Debug("Applying preload: %s", relationFieldName)
query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery { query = query.PreloadRelation(relationFieldName, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
preload.Columns = reflection.GetSQLModelColumns(model)
}
// Handle column selection and omission
if len(preload.OmitColumns) > 0 { if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model) allCols := reflection.GetSQLModelColumns(model)
// Remove omitted columns // Remove omitted columns
preload.Columns = []string{} preload.Columns = []string{}
for _, col := range allCols { for _, col := range allCols {
@@ -1194,7 +1351,12 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
} }
if len(preload.Where) > 0 { if len(preload.Where) > 0 {
sq = sq.Where(preload.Where) // Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: preloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
} }
if preload.Limit != nil && *preload.Limit > 0 { if preload.Limit != nil && *preload.Limit > 0 {
@@ -1207,7 +1369,7 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName) logger.Debug("Applied Preload for relation: %s (field: %s)", preload.Relation, relationFieldName)
} }
return query return query, nil
} }
func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo { func (h *Handler) getRelationshipInfo(modelType reflect.Type, relationName string) *relationshipInfo {
@@ -1286,3 +1448,31 @@ func toSnakeCase(s string) string {
} }
return strings.ToLower(result.String()) return strings.ToLower(result.String())
} }
// HandleOpenAPI generates and returns the OpenAPI specification
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
if h.openAPIGenerator == nil {
logger.Error("OpenAPI generator not configured")
h.sendError(w, http.StatusInternalServerError, "openapi_not_configured", "OpenAPI generation not configured", nil)
return
}
spec, err := h.openAPIGenerator()
if err != nil {
logger.Error("Failed to generate OpenAPI spec: %v", err)
h.sendError(w, http.StatusInternalServerError, "openapi_generation_error", "Failed to generate OpenAPI specification", err)
return
}
w.SetHeader("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte(spec))
if err != nil {
logger.Error("Error sending OpenAPI spec response: %v", err)
}
}
// SetOpenAPIGenerator sets the OpenAPI generator function
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
h.openAPIGenerator = generator
}

View File

@@ -0,0 +1,367 @@
package resolvespec
import (
"reflect"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestNewHandler(t *testing.T) {
// Note: We can't create a real handler without actual DB and registry
// But we can test that the constructor doesn't panic with nil values
handler := NewHandler(nil, nil)
if handler == nil {
t.Error("Expected handler to be created, got nil")
}
if handler.hooks == nil {
t.Error("Expected hooks registry to be initialized")
}
}
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(nil, nil)
hooks := handler.Hooks()
if hooks == nil {
t.Error("Expected hooks registry, got nil")
}
}
func TestSetFallbackHandler(t *testing.T) {
handler := NewHandler(nil, nil)
// We can't directly call the fallback without mocks, but we can verify it's set
handler.SetFallbackHandler(func(w common.ResponseWriter, r common.Request, params map[string]string) {
// Fallback handler implementation
})
if handler.fallbackHandler == nil {
t.Error("Expected fallback handler to be set")
}
}
func TestGetDatabase(t *testing.T) {
handler := NewHandler(nil, nil)
db := handler.GetDatabase()
// Should return nil since we passed nil
if db != nil {
t.Error("Expected nil database")
}
}
func TestParseTableName(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
fullTableName string
expectedSchema string
expectedTable string
}{
{
name: "Table with schema",
fullTableName: "public.users",
expectedSchema: "public",
expectedTable: "users",
},
{
name: "Table without schema",
fullTableName: "users",
expectedSchema: "",
expectedTable: "users",
},
{
name: "Multiple dots (use last)",
fullTableName: "db.public.users",
expectedSchema: "db.public",
expectedTable: "users",
},
{
name: "Empty string",
fullTableName: "",
expectedSchema: "",
expectedTable: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, table := handler.parseTableName(tt.fullTableName)
if schema != tt.expectedSchema {
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
}
if table != tt.expectedTable {
t.Errorf("Expected table '%s', got '%s'", tt.expectedTable, table)
}
})
}
}
func TestGetColumnType(t *testing.T) {
tests := []struct {
name string
field reflect.StructField
expectedType string
}{
{
name: "String field",
field: reflect.StructField{
Name: "Name",
Type: reflect.TypeOf(""),
},
expectedType: "string",
},
{
name: "Int field",
field: reflect.StructField{
Name: "Count",
Type: reflect.TypeOf(int(0)),
},
expectedType: "integer",
},
{
name: "Int32 field",
field: reflect.StructField{
Name: "ID",
Type: reflect.TypeOf(int32(0)),
},
expectedType: "integer",
},
{
name: "Int64 field",
field: reflect.StructField{
Name: "BigID",
Type: reflect.TypeOf(int64(0)),
},
expectedType: "bigint",
},
{
name: "Float32 field",
field: reflect.StructField{
Name: "Price",
Type: reflect.TypeOf(float32(0)),
},
expectedType: "float",
},
{
name: "Float64 field",
field: reflect.StructField{
Name: "Amount",
Type: reflect.TypeOf(float64(0)),
},
expectedType: "double",
},
{
name: "Bool field",
field: reflect.StructField{
Name: "Active",
Type: reflect.TypeOf(false),
},
expectedType: "boolean",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
colType := getColumnType(tt.field)
if colType != tt.expectedType {
t.Errorf("Expected column type '%s', got '%s'", tt.expectedType, colType)
}
})
}
}
func TestIsNullable(t *testing.T) {
tests := []struct {
name string
field reflect.StructField
nullable bool
}{
{
name: "Pointer type is nullable",
field: reflect.StructField{
Name: "Name",
Type: reflect.TypeOf((*string)(nil)),
},
nullable: true,
},
{
name: "Non-pointer type without explicit 'not null' tag",
field: reflect.StructField{
Name: "ID",
Type: reflect.TypeOf(int(0)),
},
nullable: true, // isNullable returns true if there's no explicit "not null" tag
},
{
name: "Field with 'not null' tag is not nullable",
field: reflect.StructField{
Name: "Email",
Type: reflect.TypeOf(""),
Tag: `gorm:"not null"`,
},
nullable: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isNullable(tt.field)
if result != tt.nullable {
t.Errorf("Expected nullable=%v, got %v", tt.nullable, result)
}
})
}
}
func TestToSnakeCase(t *testing.T) {
tests := []struct {
input string
expected string
}{
{
input: "UserID",
expected: "user_id",
},
{
input: "DepartmentName",
expected: "department_name",
},
{
input: "ID",
expected: "id",
},
{
input: "HTTPServer",
expected: "http_server",
},
{
input: "createdAt",
expected: "created_at",
},
{
input: "name",
expected: "name",
},
{
input: "",
expected: "",
},
{
input: "A",
expected: "a",
},
{
input: "APIKey",
expected: "api_key",
},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := toSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("toSnakeCase(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
func TestExtractTagValue(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
tag string
key string
expected string
}{
{
name: "Extract foreignKey",
tag: "foreignKey:UserID;references:ID",
key: "foreignKey",
expected: "UserID",
},
{
name: "Extract references",
tag: "foreignKey:UserID;references:ID",
key: "references",
expected: "ID",
},
{
name: "Key not found",
tag: "foreignKey:UserID;references:ID",
key: "notfound",
expected: "",
},
{
name: "Empty tag",
tag: "",
key: "foreignKey",
expected: "",
},
{
name: "Single value",
tag: "many2many:user_roles",
key: "many2many",
expected: "user_roles",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.extractTagValue(tt.tag, tt.key)
if result != tt.expected {
t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
}
})
}
}
func TestApplyFilter(t *testing.T) {
// Note: Without a real database, we can't fully test query execution
// But we can test that the method exists
_ = NewHandler(nil, nil)
// The applyFilter method exists and can be tested with actual queries
// but requires database setup which is beyond unit test scope
t.Log("applyFilter method exists and is used in handler operations")
}
func TestShouldUseNestedProcessor(t *testing.T) {
handler := NewHandler(nil, nil)
tests := []struct {
name string
data map[string]interface{}
expected bool
}{
{
name: "Has _request field",
data: map[string]interface{}{
"_request": "nested",
"name": "test",
},
expected: true,
},
{
name: "No special fields",
data: map[string]interface{}{
"name": "test",
"email": "test@example.com",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Note: Without a real model, we can't fully test this
// But we can verify the function exists
result := handler.shouldUseNestedProcessor(tt.data, nil)
// The actual result depends on the model structure
_ = result
})
}
}

152
pkg/resolvespec/hooks.go Normal file
View File

@@ -0,0 +1,152 @@
package resolvespec
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
// Create operation hooks
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
// Update operation hooks
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
// Delete operation hooks
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
// Scan/Execute operation hooks (for query building)
BeforeScan HookType = "before_scan"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database, registry, etc.
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Writer common.ResponseWriter
Request common.Request
// Operation-specific fields
ID string
Data interface{} // For create/update operations
Result interface{} // For after hooks
Error error // For after hooks
// Query chain - allows hooks to modify the query before execution
Query common.SelectQuery
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvespec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvespec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Resolvespec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Resolvespec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all resolvespec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all resolvespec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@@ -0,0 +1,400 @@
package resolvespec
import (
"context"
"fmt"
"testing"
)
func TestHookRegistry(t *testing.T) {
registry := NewHookRegistry()
// Test registering a hook
called := false
hook := func(ctx *HookContext) error {
called = true
return nil
}
registry.Register(BeforeRead, hook)
if registry.Count(BeforeRead) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
}
// Test executing a hook
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !called {
t.Error("Hook was not called")
}
}
func TestHookExecutionOrder(t *testing.T) {
registry := NewHookRegistry()
order := []int{}
hook1 := func(ctx *HookContext) error {
order = append(order, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
order = append(order, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
order = append(order, 3)
return nil
}
registry.Register(BeforeCreate, hook1)
registry.Register(BeforeCreate, hook2)
registry.Register(BeforeCreate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if len(order) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
}
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Errorf("Hooks executed in wrong order: %v", order)
}
}
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
executed := []string{}
hook1 := func(ctx *HookContext) error {
executed = append(executed, "hook1")
return nil
}
hook2 := func(ctx *HookContext) error {
executed = append(executed, "hook2")
return fmt.Errorf("hook2 error")
}
hook3 := func(ctx *HookContext) error {
executed = append(executed, "hook3")
return nil
}
registry.Register(BeforeUpdate, hook1)
registry.Register(BeforeUpdate, hook2)
registry.Register(BeforeUpdate, hook3)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeUpdate, ctx)
if err == nil {
t.Error("Expected error from hook execution")
}
if len(executed) != 2 {
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
}
if executed[0] != "hook1" || executed[1] != "hook2" {
t.Errorf("Unexpected execution order: %v", executed)
}
}
func TestHookDataModification(t *testing.T) {
registry := NewHookRegistry()
modifyHook := func(ctx *HookContext) error {
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
dataMap["modified"] = true
ctx.Data = dataMap
}
return nil
}
registry.Register(BeforeCreate, modifyHook)
data := map[string]interface{}{
"name": "test",
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
Data: data,
}
err := registry.Execute(BeforeCreate, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
modifiedData := ctx.Data.(map[string]interface{})
if !modifiedData["modified"].(bool) {
t.Error("Data was not modified by hook")
}
}
func TestRegisterMultiple(t *testing.T) {
registry := NewHookRegistry()
called := 0
hook := func(ctx *HookContext) error {
called++
return nil
}
registry.RegisterMultiple([]HookType{
BeforeRead,
BeforeCreate,
BeforeUpdate,
}, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered for BeforeRead")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Hook not registered for BeforeCreate")
}
if registry.Count(BeforeUpdate) != 1 {
t.Error("Hook not registered for BeforeUpdate")
}
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
registry.Execute(BeforeRead, ctx)
registry.Execute(BeforeCreate, ctx)
registry.Execute(BeforeUpdate, ctx)
if called != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", called)
}
}
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
if registry.Count(BeforeRead) != 1 {
t.Error("Hook not registered")
}
registry.Clear(BeforeRead)
if registry.Count(BeforeRead) != 0 {
t.Error("Hook not cleared")
}
if registry.Count(BeforeCreate) != 1 {
t.Error("Wrong hook was cleared")
}
}
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(BeforeUpdate, hook)
registry.ClearAll()
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
t.Error("Not all hooks were cleared")
}
}
func TestHasHooks(t *testing.T) {
registry := NewHookRegistry()
if registry.HasHooks(BeforeRead) {
t.Error("Should not have hooks initially")
}
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
if !registry.HasHooks(BeforeRead) {
t.Error("Should have hooks after registration")
}
}
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
hook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeRead, hook)
registry.Register(BeforeCreate, hook)
registry.Register(AfterUpdate, hook)
types := registry.GetAllHookTypes()
if len(types) != 3 {
t.Errorf("Expected 3 hook types, got %d", len(types))
}
// Verify all expected types are present
expectedTypes := map[HookType]bool{
BeforeRead: true,
BeforeCreate: true,
AfterUpdate: true,
}
for _, hookType := range types {
if !expectedTypes[hookType] {
t.Errorf("Unexpected hook type: %s", hookType)
}
}
}
func TestHookContextHandler(t *testing.T) {
registry := NewHookRegistry()
var capturedHandler *Handler
hook := func(ctx *HookContext) error {
if ctx.Handler == nil {
return fmt.Errorf("handler is nil in hook context")
}
capturedHandler = ctx.Handler
return nil
}
registry.Register(BeforeRead, hook)
handler := &Handler{
hooks: registry,
}
ctx := &HookContext{
Context: context.Background(),
Handler: handler,
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if capturedHandler == nil {
t.Error("Handler was not captured from hook context")
}
if capturedHandler != handler {
t.Error("Captured handler does not match original handler")
}
}
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeCreate, abortHook)
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
err := registry.Execute(BeforeCreate, ctx)
if err == nil {
t.Error("Expected error when hook sets Abort=true")
}
if err.Error() != "operation aborted by hook: Operation aborted by hook" {
t.Errorf("Expected abort error message, got: %v", err)
}
}
func TestHookTypes(t *testing.T) {
// Test all hook type constants
hookTypes := []HookType{
BeforeRead,
AfterRead,
BeforeCreate,
AfterCreate,
BeforeUpdate,
AfterUpdate,
BeforeDelete,
AfterDelete,
BeforeScan,
}
for _, hookType := range hookTypes {
if string(hookType) == "" {
t.Errorf("Hook type should not be empty: %v", hookType)
}
}
}
func TestExecuteWithNoHooks(t *testing.T) {
registry := NewHookRegistry()
ctx := &HookContext{
Context: context.Background(),
Schema: "test",
Entity: "users",
}
// Executing with no registered hooks should not cause an error
err := registry.Execute(BeforeRead, ctx)
if err != nil {
t.Errorf("Execute should not fail with no hooks, got: %v", err)
}
}

View File

@@ -0,0 +1,508 @@
// +build integration
package resolvespec
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID uint `gorm:"primaryKey" json:"id"`
Name string `gorm:"not null" json:"name"`
Email string `gorm:"uniqueIndex;not null" json:"email"`
Age int `json:"age"`
Active bool `gorm:"default:true" json:"active"`
CreatedAt time.Time `json:"created_at"`
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
}
func (TestUser) TableName() string {
return "test_users"
}
type TestPost struct {
ID uint `gorm:"primaryKey" json:"id"`
UserID uint `gorm:"not null" json:"user_id"`
Title string `gorm:"not null" json:"title"`
Content string `json:"content"`
Published bool `gorm:"default:false" json:"published"`
CreatedAt time.Time `json:"created_at"`
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
}
func (TestPost) TableName() string {
return "test_posts"
}
type TestComment struct {
ID uint `gorm:"primaryKey" json:"id"`
PostID uint `gorm:"not null" json:"post_id"`
Content string `gorm:"not null" json:"content"`
CreatedAt time.Time `json:"created_at"`
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
}
func (TestComment) TableName() string {
return "test_comments"
}
// Test helper functions
func setupTestDB(t *testing.T) *gorm.DB {
// Get connection string from environment or use default
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
dsn = "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("Skipping integration test: database not available: %v", err)
return nil
}
// Run migrations
err = db.AutoMigrate(&TestUser{}, &TestPost{}, &TestComment{})
if err != nil {
t.Skipf("Skipping integration test: failed to migrate database: %v", err)
return nil
}
return db
}
func cleanupTestDB(t *testing.T, db *gorm.DB) {
// Clean up test data
db.Exec("TRUNCATE TABLE test_comments CASCADE")
db.Exec("TRUNCATE TABLE test_posts CASCADE")
db.Exec("TRUNCATE TABLE test_users CASCADE")
}
func createTestData(t *testing.T, db *gorm.DB) {
users := []TestUser{
{Name: "John Doe", Email: "john@example.com", Age: 30, Active: true},
{Name: "Jane Smith", Email: "jane@example.com", Age: 25, Active: true},
{Name: "Bob Johnson", Email: "bob@example.com", Age: 35, Active: false},
}
for i := range users {
if err := db.Create(&users[i]).Error; err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
}
posts := []TestPost{
{UserID: users[0].ID, Title: "First Post", Content: "Hello World", Published: true},
{UserID: users[0].ID, Title: "Second Post", Content: "More content", Published: true},
{UserID: users[1].ID, Title: "Jane's Post", Content: "Jane's content", Published: false},
}
for i := range posts {
if err := db.Create(&posts[i]).Error; err != nil {
t.Fatalf("Failed to create test post: %v", err)
}
}
comments := []TestComment{
{PostID: posts[0].ID, Content: "Great post!"},
{PostID: posts[0].ID, Content: "Thanks for sharing"},
{PostID: posts[1].ID, Content: "Interesting"},
}
for i := range comments {
if err := db.Create(&comments[i]).Error; err != nil {
t.Fatalf("Failed to create test comment: %v", err)
}
}
}
// Integration tests
func TestIntegration_CreateOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Create a new user
requestBody := map[string]interface{}{
"operation": "create",
"data": map[string]interface{}{
"name": "Test User",
"email": "test@example.com",
"age": 28,
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v. Error: %v", response.Success, response.Error)
}
// Verify user was created
var user TestUser
if err := db.Where("email = ?", "test@example.com").First(&user).Error; err != nil {
t.Errorf("Failed to find created user: %v", err)
}
if user.Name != "Test User" {
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
}
}
func TestIntegration_ReadOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read all users
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"limit": 10,
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
if response.Metadata == nil {
t.Fatal("Expected metadata, got nil")
}
if response.Metadata.Total != 3 {
t.Errorf("Expected 3 users, got %d", response.Metadata.Total)
}
}
func TestIntegration_ReadWithFilters(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read users with age > 25
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"filters": []map[string]interface{}{
{
"column": "age",
"operator": "gt",
"value": 25,
},
},
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d", w.Code)
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Should return 2 users (John: 30, Bob: 35)
if response.Metadata.Total != 2 {
t.Errorf("Expected 2 filtered users, got %d", response.Metadata.Total)
}
}
func TestIntegration_UpdateOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get user ID
var user TestUser
db.Where("email = ?", "john@example.com").First(&user)
// Update user
requestBody := map[string]interface{}{
"operation": "update",
"data": map[string]interface{}{
"id": user.ID,
"age": 31,
"name": "John Doe Updated",
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
// Verify update
var updatedUser TestUser
db.First(&updatedUser, user.ID)
if updatedUser.Age != 31 {
t.Errorf("Expected age 31, got %d", updatedUser.Age)
}
if updatedUser.Name != "John Doe Updated" {
t.Errorf("Expected name 'John Doe Updated', got '%s'", updatedUser.Name)
}
}
func TestIntegration_DeleteOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get user ID
var user TestUser
db.Where("email = ?", "bob@example.com").First(&user)
// Delete user
requestBody := map[string]interface{}{
"operation": "delete",
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
// Verify deletion
var count int64
db.Model(&TestUser{}).Where("id = ?", user.ID).Count(&count)
if count != 0 {
t.Errorf("Expected user to be deleted, but found %d records", count)
}
}
func TestIntegration_MetadataOperation(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Get metadata
requestBody := map[string]interface{}{
"operation": "meta",
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Check that metadata includes columns
// The response.Data is an interface{}, we need to unmarshal it properly
dataBytes, _ := json.Marshal(response.Data)
var metadata common.TableMetadata
if err := json.Unmarshal(dataBytes, &metadata); err != nil {
t.Fatalf("Failed to unmarshal metadata: %v. Raw data: %+v", err, response.Data)
}
if len(metadata.Columns) == 0 {
t.Error("Expected metadata to contain columns")
}
// Verify some expected columns
hasID := false
hasName := false
hasEmail := false
for _, col := range metadata.Columns {
if col.Name == "id" {
hasID = true
if !col.IsPrimary {
t.Error("Expected 'id' column to be primary key")
}
}
if col.Name == "name" {
hasName = true
}
if col.Name == "email" {
hasEmail = true
}
}
if !hasID || !hasName || !hasEmail {
t.Error("Expected metadata to contain 'id', 'name', and 'email' columns")
}
}
func TestIntegration_ReadWithPreload(t *testing.T) {
db := setupTestDB(t)
defer cleanupTestDB(t, db)
createTestData(t, db)
handler := NewHandlerWithGORM(db)
handler.RegisterModel("public", "test_users", TestUser{})
handler.RegisterModel("public", "test_posts", TestPost{})
handler.RegisterModel("public", "test_comments", TestComment{})
muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler, nil)
// Read users with posts preloaded
requestBody := map[string]interface{}{
"operation": "read",
"options": map[string]interface{}{
"filters": []map[string]interface{}{
{
"column": "email",
"operator": "eq",
"value": "john@example.com",
},
},
"preload": []map[string]interface{}{
{"relation": "posts"},
},
},
}
body, _ := json.Marshal(requestBody)
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
muxRouter.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
}
var response common.Response
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse response: %v", err)
}
if !response.Success {
t.Errorf("Expected success=true, got %v", response.Success)
}
// Verify posts are preloaded
dataBytes, _ := json.Marshal(response.Data)
var users []TestUser
json.Unmarshal(dataBytes, &users)
if len(users) == 0 {
t.Fatal("Expected at least one user")
}
if len(users[0].Posts) == 0 {
t.Error("Expected posts to be preloaded")
}
}

View File

@@ -2,12 +2,14 @@ package resolvespec
import ( import (
"net/http" "net/http"
"strings"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bunrouter" "github.com/uptrace/bunrouter"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database" "github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router" "github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
@@ -37,28 +39,132 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
return router.NewStandardBunRouterAdapter() return router.NewStandardBunRouterAdapter()
} }
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
type MiddlewareFunc func(http.Handler) http.Handler
// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux // SetupMuxRoutes sets up routes for the ResolveSpec API with Mux
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) { // authMiddleware is optional - if provided, routes will be protected with the middleware
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { // Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
vars := mux.Vars(r) func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
reqAdapter := router.NewHTTPRequest(r) // Add global /openapi route
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w) respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars) common.SetCORSHeaders(respAdapter, corsConfig)
}).Methods("POST") reqAdapter := router.NewHTTPRequest(r)
handler.HandleOpenAPI(respAdapter, reqAdapter)
})
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) { // Get all registered models from the registry
vars := mux.Vars(r) allModels := handler.registry.GetAllModels()
reqAdapter := router.NewHTTPRequest(r)
respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, vars)
}).Methods("POST")
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) { // Loop through each registered model and create explicit routes
vars := mux.Vars(r) for fullName := range allModels {
reqAdapter := router.NewHTTPRequest(r) // Parse the full name (e.g., "public.users" or just "users")
schema, entity := parseModelName(fullName)
// Build the route paths
entityPath := buildRoutePath(schema, entity)
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
// Create handler functions for this specific entity
postEntityHandler := createMuxHandler(handler, schema, entity, "")
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
// Register routes for this entity
muxRouter.Handle(entityPath, postEntityHandler).Methods("POST")
muxRouter.Handle(entityWithIDPath, postEntityWithIDHandler).Methods("POST")
muxRouter.Handle(entityPath, getEntityHandler).Methods("GET")
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
}
}
// Helper function to create Mux handler for a specific entity with CORS support
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w) respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.Handle(respAdapter, reqAdapter, vars)
}
}
// Helper function to create Mux GET handler for a specific entity with CORS support
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers
corsConfig := common.DefaultCORSConfig()
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
if idParam != "" {
vars["id"] = mux.Vars(r)[idParam]
}
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars) handler.HandleGet(respAdapter, reqAdapter, vars)
}).Methods("GET") }
}
// Helper function to create Mux OPTIONS handler that returns metadata
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Set CORS headers with the allowed methods for this route
corsConfig := common.DefaultCORSConfig()
corsConfig.AllowedMethods = allowedMethods
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
// Return metadata in the OPTIONS response body
vars := make(map[string]string)
vars["schema"] = schema
vars["entity"] = entity
reqAdapter := router.NewHTTPRequest(r)
handler.HandleGet(respAdapter, reqAdapter, vars)
}
}
// parseModelName parses a model name like "public.users" into schema and entity
// If no schema is present, returns empty string for schema
func parseModelName(fullName string) (schema, entity string) {
parts := strings.Split(fullName, ".")
if len(parts) == 2 {
return parts[0], parts[1]
}
return "", fullName
}
// buildRoutePath builds a route path from schema and entity
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
func buildRoutePath(schema, entity string) string {
if schema == "" {
return "/" + entity
}
return "/" + schema + "/" + entity
} }
// Example usage functions for documentation: // Example usage functions for documentation:
@@ -68,12 +174,20 @@ func ExampleWithGORM(db *gorm.DB) {
// Create handler using GORM // Create handler using GORM
handler := NewHandlerWithGORM(db) handler := NewHandlerWithGORM(db)
// Setup router // Setup router without authentication
muxRouter := mux.NewRouter() muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler) SetupMuxRoutes(muxRouter, handler, nil)
// Register models // Register models
// handler.RegisterModel("public", "users", &User{}) // handler.RegisterModel("public", "users", &User{})
// To add authentication, pass a middleware function:
// import "github.com/bitechdev/ResolveSpec/pkg/security"
// secList := security.NewSecurityList(myProvider)
// authMiddleware := func(h http.Handler) http.Handler {
// return security.NewAuthHandler(secList, h)
// }
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
} }
// ExampleWithBun shows how to switch to Bun ORM // ExampleWithBun shows how to switch to Bun ORM
@@ -88,60 +202,133 @@ func ExampleWithBun(bunDB *bun.DB) {
// Create handler // Create handler
handler := NewHandler(dbAdapter, registry) handler := NewHandler(dbAdapter, registry)
// Setup routes // Setup routes without authentication
muxRouter := mux.NewRouter() muxRouter := mux.NewRouter()
SetupMuxRoutes(muxRouter, handler) SetupMuxRoutes(muxRouter, handler, nil)
} }
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API // SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) { func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
r := bunRouter.GetBunRouter() r := bunRouter.GetBunRouter()
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { // CORS config
params := map[string]string{ corsConfig := common.DefaultCORSConfig()
"schema": req.Param("schema"),
"entity": req.Param("entity"), // Add global /openapi route
} r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w) respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params) common.SetCORSHeaders(respAdapter, corsConfig)
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleOpenAPI(respAdapter, reqAdapter)
return nil return nil
}) })
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
params := map[string]string{
"schema": req.Param("schema"),
"entity": req.Param("entity"),
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w) respAdapter := router.NewHTTPResponseWriter(w)
handler.Handle(respAdapter, reqAdapter, params) common.SetCORSHeaders(respAdapter, corsConfig)
return nil return nil
}) })
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error { // Get all registered models from the registry
params := map[string]string{ allModels := handler.registry.GetAllModels()
"schema": req.Param("schema"),
"entity": req.Param("entity"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
respAdapter := router.NewHTTPResponseWriter(w)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error { // Loop through each registered model and create explicit routes
params := map[string]string{ for fullName := range allModels {
"schema": req.Param("schema"), // Parse the full name (e.g., "public.users" or just "users")
"entity": req.Param("entity"), schema, entity := parseModelName(fullName)
"id": req.Param("id"),
} // Build the route paths
reqAdapter := router.NewHTTPRequest(req.Request) entityPath := buildRoutePath(schema, entity)
respAdapter := router.NewHTTPResponseWriter(w) entityWithIDPath := entityPath + "/:id"
handler.HandleGet(respAdapter, reqAdapter, params)
return nil // Create closure variables to capture current schema and entity
}) currentSchema := schema
currentEntity := entity
// POST route without ID
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// POST route with ID
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.Handle(respAdapter, reqAdapter, params)
return nil
})
// GET route without ID
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// GET route with ID
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
common.SetCORSHeaders(respAdapter, corsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
"id": req.Param("id"),
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route without ID (returns metadata)
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
// OPTIONS route with ID (returns metadata)
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
respAdapter := router.NewHTTPResponseWriter(w)
optionsCorsConfig := corsConfig
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
params := map[string]string{
"schema": currentSchema,
"entity": currentEntity,
}
reqAdapter := router.NewHTTPRequest(req.Request)
handler.HandleGet(respAdapter, reqAdapter, params)
return nil
})
}
} }
// ExampleWithBunRouter shows how to use bunrouter from uptrace // ExampleWithBunRouter shows how to use bunrouter from uptrace

View File

@@ -0,0 +1,114 @@
package resolvespec
import (
"testing"
)
func TestParseModelName(t *testing.T) {
tests := []struct {
name string
fullName string
expectedSchema string
expectedEntity string
}{
{
name: "Model with schema",
fullName: "public.users",
expectedSchema: "public",
expectedEntity: "users",
},
{
name: "Model without schema",
fullName: "users",
expectedSchema: "",
expectedEntity: "users",
},
{
name: "Model with custom schema",
fullName: "myschema.products",
expectedSchema: "myschema",
expectedEntity: "products",
},
{
name: "Empty string",
fullName: "",
expectedSchema: "",
expectedEntity: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, entity := parseModelName(tt.fullName)
if schema != tt.expectedSchema {
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
}
if entity != tt.expectedEntity {
t.Errorf("Expected entity '%s', got '%s'", tt.expectedEntity, entity)
}
})
}
}
func TestBuildRoutePath(t *testing.T) {
tests := []struct {
name string
schema string
entity string
expectedPath string
}{
{
name: "With schema",
schema: "public",
entity: "users",
expectedPath: "/public/users",
},
{
name: "Without schema",
schema: "",
entity: "users",
expectedPath: "/users",
},
{
name: "Custom schema",
schema: "admin",
entity: "logs",
expectedPath: "/admin/logs",
},
{
name: "Empty entity with schema",
schema: "public",
entity: "",
expectedPath: "/public/",
},
{
name: "Both empty",
schema: "",
entity: "",
expectedPath: "/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := buildRoutePath(tt.schema, tt.entity)
if path != tt.expectedPath {
t.Errorf("Expected path '%s', got '%s'", tt.expectedPath, path)
}
})
}
}
func TestNewStandardMuxRouter(t *testing.T) {
router := NewStandardMuxRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}
func TestNewStandardBunRouter(t *testing.T) {
router := NewStandardBunRouter()
if router == nil {
t.Error("Expected router to be created, got nil")
}
}

View File

@@ -0,0 +1,85 @@
package resolvespec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LoadSecurityRules(secCtx, securityList)
})
// Hook 2: BeforeScan - Apply row-level security filters
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyRowSecurity(secCtx, securityList)
})
// Hook 3: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyColumnSecurity(secCtx, securityList)
})
// Hook 4 (Optional): Audit logging
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
logger.Info("Security hooks registered for resolvespec handler")
}
// securityContext adapts resolvespec.HookContext to security.SecurityContext interface
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

View File

@@ -0,0 +1,181 @@
package restheadspec
import (
"context"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestContextOperations(t *testing.T) {
ctx := context.Background()
// Test Schema
t.Run("WithSchema and GetSchema", func(t *testing.T) {
ctx = WithSchema(ctx, "public")
schema := GetSchema(ctx)
if schema != "public" {
t.Errorf("Expected schema 'public', got '%s'", schema)
}
})
// Test Entity
t.Run("WithEntity and GetEntity", func(t *testing.T) {
ctx = WithEntity(ctx, "users")
entity := GetEntity(ctx)
if entity != "users" {
t.Errorf("Expected entity 'users', got '%s'", entity)
}
})
// Test TableName
t.Run("WithTableName and GetTableName", func(t *testing.T) {
ctx = WithTableName(ctx, "public.users")
tableName := GetTableName(ctx)
if tableName != "public.users" {
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
}
})
// Test Model
t.Run("WithModel and GetModel", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
ctx = WithModel(ctx, model)
retrieved := GetModel(ctx)
if retrieved == nil {
t.Error("Expected model to be retrieved, got nil")
}
if retrievedModel, ok := retrieved.(*TestModel); ok {
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
}
} else {
t.Error("Retrieved model is not of expected type")
}
})
// Test ModelPtr
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
type TestModel struct {
ID int
}
models := []*TestModel{}
ctx = WithModelPtr(ctx, &models)
retrieved := GetModelPtr(ctx)
if retrieved == nil {
t.Error("Expected modelPtr to be retrieved, got nil")
}
})
// Test Options
t.Run("WithOptions and GetOptions", func(t *testing.T) {
limit := 10
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Limit: &limit,
},
}
ctx = WithOptions(ctx, options)
retrieved := GetOptions(ctx)
if retrieved == nil {
t.Error("Expected options to be retrieved, got nil")
return
}
if retrieved.Limit == nil || *retrieved.Limit != 10 {
t.Error("Expected options to be retrieved with limit=10")
}
})
// Test WithRequestData
t.Run("WithRequestData", func(t *testing.T) {
type TestModel struct {
ID int
Name string
}
model := &TestModel{ID: 1, Name: "test"}
modelPtr := &[]*TestModel{}
limit := 20
options := ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Limit: &limit,
},
}
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr, options)
if GetSchema(ctx) != "test_schema" {
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
}
if GetEntity(ctx) != "test_entity" {
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
}
if GetTableName(ctx) != "test_schema.test_entity" {
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
}
if GetModel(ctx) == nil {
t.Error("Expected model to be set")
}
if GetModelPtr(ctx) == nil {
t.Error("Expected modelPtr to be set")
}
opts := GetOptions(ctx)
if opts == nil {
t.Error("Expected options to be set")
return
}
if opts.Limit == nil || *opts.Limit != 20 {
t.Error("Expected options to be set with limit=20")
}
})
}
func TestEmptyContext(t *testing.T) {
ctx := context.Background()
t.Run("GetSchema with empty context", func(t *testing.T) {
schema := GetSchema(ctx)
if schema != "" {
t.Errorf("Expected empty schema, got '%s'", schema)
}
})
t.Run("GetEntity with empty context", func(t *testing.T) {
entity := GetEntity(ctx)
if entity != "" {
t.Errorf("Expected empty entity, got '%s'", entity)
}
})
t.Run("GetTableName with empty context", func(t *testing.T) {
tableName := GetTableName(ctx)
if tableName != "" {
t.Errorf("Expected empty tableName, got '%s'", tableName)
}
})
t.Run("GetModel with empty context", func(t *testing.T) {
model := GetModel(ctx)
if model != nil {
t.Errorf("Expected nil model, got %v", model)
}
})
t.Run("GetModelPtr with empty context", func(t *testing.T) {
modelPtr := GetModelPtr(ctx)
if modelPtr != nil {
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
}
})
t.Run("GetOptions with empty context", func(t *testing.T) {
options := GetOptions(ctx)
// GetOptions returns nil when context is empty
if options != nil {
t.Errorf("Expected nil options in empty context, got %v", options)
}
})
}

View File

@@ -0,0 +1,305 @@
package restheadspec
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestGetCursorFilter_Forward(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
},
}
opts.CursorForward = "123"
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
if filter == "" {
t.Fatal("Expected non-empty cursor filter")
}
// Verify filter contains EXISTS subquery
if !strings.Contains(filter, "EXISTS") {
t.Errorf("Filter should contain EXISTS subquery, got: %s", filter)
}
// Verify filter references the cursor ID
if !strings.Contains(filter, "123") {
t.Errorf("Filter should reference cursor ID 123, got: %s", filter)
}
// Verify filter contains the table name
if !strings.Contains(filter, tableName) {
t.Errorf("Filter should reference table name %s, got: %s", tableName, filter)
}
// Verify filter contains primary key
if !strings.Contains(filter, pkName) {
t.Errorf("Filter should reference primary key %s, got: %s", pkName, filter)
}
t.Logf("Generated cursor filter: %s", filter)
}
func TestGetCursorFilter_Backward(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
},
}
opts.CursorBackward = "456"
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
if filter == "" {
t.Fatal("Expected non-empty cursor filter")
}
// Verify filter contains cursor ID
if !strings.Contains(filter, "456") {
t.Errorf("Filter should reference cursor ID 456, got: %s", filter)
}
// For backward cursor, sort direction should be reversed
// This is handled internally by the GetCursorFilter method
t.Logf("Generated backward cursor filter: %s", filter)
}
func TestGetCursorFilter_NoCursor(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "created_at", Direction: "DESC"},
},
},
}
// No cursor set
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title", "created_at"}
_, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err == nil {
t.Error("Expected error when no cursor is provided")
}
if !strings.Contains(err.Error(), "no cursor provided") {
t.Errorf("Expected 'no cursor provided' error, got: %v", err)
}
}
func TestGetCursorFilter_NoSort(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{},
},
}
opts.CursorForward = "123"
tableName := "posts"
pkName := "id"
modelColumns := []string{"id", "title"}
_, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err == nil {
t.Error("Expected error when no sort columns are defined")
}
if !strings.Contains(err.Error(), "no sort columns") {
t.Errorf("Expected 'no sort columns' error, got: %v", err)
}
}
func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "priority", Direction: "DESC"},
{Column: "created_at", Direction: "DESC"},
{Column: "id", Direction: "ASC"},
},
},
}
opts.CursorForward = "789"
tableName := "tasks"
pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Verify filter contains priority column
if !strings.Contains(filter, "priority") {
t.Errorf("Filter should reference priority column, got: %s", filter)
}
// Verify filter contains created_at column
if !strings.Contains(filter, "created_at") {
t.Errorf("Filter should reference created_at column, got: %s", filter)
}
t.Logf("Generated multi-column cursor filter: %s", filter)
}
func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "name", Direction: "ASC"},
},
},
}
opts.CursorForward = "100"
tableName := "public.users"
pkName := "id"
modelColumns := []string{"id", "name", "email"}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
}
func TestGetActiveCursor(t *testing.T) {
tests := []struct {
name string
cursorForward string
cursorBackward string
expectedID string
expectedDirection CursorDirection
}{
{
name: "Forward cursor only",
cursorForward: "123",
cursorBackward: "",
expectedID: "123",
expectedDirection: CursorForward,
},
{
name: "Backward cursor only",
cursorForward: "",
cursorBackward: "456",
expectedID: "456",
expectedDirection: CursorBackward,
},
{
name: "Both cursors - forward takes precedence",
cursorForward: "123",
cursorBackward: "456",
expectedID: "123",
expectedDirection: CursorForward,
},
{
name: "No cursors",
cursorForward: "",
cursorBackward: "",
expectedID: "",
expectedDirection: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &ExtendedRequestOptions{}
opts.CursorForward = tt.cursorForward
opts.CursorBackward = tt.cursorBackward
id, direction := opts.getActiveCursor()
if id != tt.expectedID {
t.Errorf("Expected cursor ID %q, got %q", tt.expectedID, id)
}
if direction != tt.expectedDirection {
t.Errorf("Expected direction %d, got %d", tt.expectedDirection, direction)
}
})
}
}
func TestCleanSortField(t *testing.T) {
opts := &ExtendedRequestOptions{}
tests := []struct {
input string
expected string
}{
{"created_at desc", "created_at"},
{"name asc", "name"},
{"priority desc nulls last", "priority"},
{"id asc nulls first", "id"},
{"title", "title"},
{"updated_at DESC", "updated_at"},
{" status asc ", "status"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := opts.cleanSortField(tt.input)
if result != tt.expected {
t.Errorf("cleanSortField(%q) = %q, expected %q", tt.input, result, tt.expected)
}
})
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > posts.priority",
"cursor_select.created_at > posts.created_at",
"cursor_select.id < posts.id",
}
result := buildPriorityChain(clauses)
// Should build OR-AND chain for cursor comparison
if !strings.Contains(result, "OR") {
t.Error("Priority chain should contain OR operators")
}
if !strings.Contains(result, "AND") {
t.Error("Priority chain should contain AND operators for composite conditions")
}
// First clause should appear standalone
if !strings.Contains(result, clauses[0]) {
t.Errorf("Priority chain should contain first clause: %s", clauses[0])
}
t.Logf("Built priority chain: %s", result)
}

View File

@@ -9,19 +9,27 @@ import (
"runtime/debug" "runtime/debug"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
"github.com/bitechdev/ResolveSpec/pkg/common" "github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection" "github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// FallbackHandler is a function that handles requests when no model is found
// It receives the same parameters as the Handle method
type FallbackHandler func(w common.ResponseWriter, r common.Request, params map[string]string)
// Handler handles API requests using database and model abstractions // Handler handles API requests using database and model abstractions
// This handler reads filters, columns, and options from HTTP headers // This handler reads filters, columns, and options from HTTP headers
type Handler struct { type Handler struct {
db common.Database db common.Database
registry common.ModelRegistry registry common.ModelRegistry
hooks *HookRegistry hooks *HookRegistry
nestedProcessor *common.NestedCUDProcessor nestedProcessor *common.NestedCUDProcessor
fallbackHandler FallbackHandler
openAPIGenerator func() (string, error)
} }
// NewHandler creates a new API handler with database and registry abstractions // NewHandler creates a new API handler with database and registry abstractions
@@ -36,12 +44,24 @@ func NewHandler(db common.Database, registry common.ModelRegistry) *Handler {
return handler return handler
} }
// GetDatabase returns the underlying database connection
// Implements common.SpecHandler interface
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// Hooks returns the hook registry for this handler // Hooks returns the hook registry for this handler
// Use this to register custom hooks for operations // Use this to register custom hooks for operations
func (h *Handler) Hooks() *HookRegistry { func (h *Handler) Hooks() *HookRegistry {
return h.hooks return h.hooks
} }
// SetFallbackHandler sets a fallback handler to be called when no model is found
// If not set, the handler will simply return (pass through to next route)
func (h *Handler) SetFallbackHandler(fallback FallbackHandler) {
h.fallbackHandler = fallback
}
// handlePanic is a helper function to handle panics with stack traces // handlePanic is a helper function to handle panics with stack traces
func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) { func (h *Handler) handlePanic(w common.ResponseWriter, method string, err interface{}) {
stack := debug.Stack() stack := debug.Stack()
@@ -59,7 +79,13 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
} }
}() }()
ctx := context.Background() // Check for ?openapi query parameter
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
h.HandleOpenAPI(w, r)
return
}
ctx := r.UnderlyingRequest().Context()
schema := params["schema"] schema := params["schema"]
entity := params["entity"] entity := params["entity"]
@@ -73,32 +99,27 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Get model and populate context with request-scoped data // Get model and populate context with request-scoped data
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
logger.Error("Invalid entity: %v", err) // Model not found - call fallback handler if set, otherwise pass through
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return return
} }
// Validate that the model is a struct type (not a slice or pointer to slice) // Validate and unwrap model using common utility
modelType := reflect.TypeOf(model) result, err := common.ValidateAndUnwrapModel(model)
originalType := modelType if err != nil {
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { logger.Error("Model for %s.%s validation failed: %v", schema, entity, err)
modelType = modelType.Elem() h.sendError(w, http.StatusInternalServerError, "invalid_model_type", err.Error(), err)
}
if modelType == nil || modelType.Kind() != reflect.Struct {
logger.Error("Model for %s.%s must be a struct type, got %v. Please register models as struct types, not slices or pointers to slices.", schema, entity, originalType)
h.sendError(w, http.StatusInternalServerError, "invalid_model_type",
fmt.Sprintf("Model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType),
fmt.Errorf("invalid model type: %v", originalType))
return return
} }
// If the registered model was a pointer or slice, use the unwrapped struct type model = result.Model
if originalType != modelType { modelPtr := result.ModelPtr
model = reflect.New(modelType).Elem().Interface()
}
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
// Parse options from headers - this now includes relation name resolution // Parse options from headers - this now includes relation name resolution
@@ -121,13 +142,25 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
h.handleRead(ctx, w, "", options) h.handleRead(ctx, w, "", options)
} }
case "POST": case "POST":
// Create operation // Read request body
body, err := r.Body() body, err := r.Body()
if err != nil { if err != nil {
logger.Error("Failed to read request body: %v", err) logger.Error("Failed to read request body: %v", err)
h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err) h.sendError(w, http.StatusBadRequest, "invalid_request", "Failed to read request body", err)
return return
} }
// Try to detect if this is a meta operation request
var bodyMap map[string]interface{}
if err := json.Unmarshal(body, &bodyMap); err == nil {
if operation, ok := bodyMap["operation"].(string); ok && operation == "meta" {
logger.Info("Detected meta operation request for %s.%s", schema, entity)
h.handleMeta(ctx, w, schema, entity, model)
return
}
}
// Not a meta operation, proceed with normal create/update
var data interface{} var data interface{}
if err := json.Unmarshal(body, &data); err != nil { if err := json.Unmarshal(body, &data); err != nil {
logger.Error("Failed to decode request body: %v", err) logger.Error("Failed to decode request body: %v", err)
@@ -182,6 +215,12 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
} }
}() }()
// Check for ?openapi query parameter
if r.UnderlyingRequest().URL.Query().Get("openapi") != "" {
h.HandleOpenAPI(w, r)
return
}
schema := params["schema"] schema := params["schema"]
entity := params["entity"] entity := params["entity"]
@@ -189,11 +228,44 @@ func (h *Handler) HandleGet(w common.ResponseWriter, r common.Request, params ma
model, err := h.registry.GetModelByEntity(schema, entity) model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil { if err != nil {
logger.Error("Failed to get model: %v", err) // Model not found - call fallback handler if set, otherwise pass through
h.sendError(w, http.StatusBadRequest, "invalid_entity", "Invalid entity", err) logger.Debug("Model not found for %s.%s", schema, entity)
if h.fallbackHandler != nil {
logger.Debug("Calling fallback handler for %s.%s", schema, entity)
h.fallbackHandler(w, r, params)
} else {
logger.Debug("No fallback handler set, passing through to next route")
}
return return
} }
// Parse request options from headers to get response format settings
options := h.parseOptionsFromHeaders(r, model)
tableMetadata := h.generateMetadata(schema, entity, model)
// Send with formatted response to respect DetailApi/SimpleApi/Syncfusion format
// Create empty metadata for response wrapper
responseMetadata := &common.Metadata{
Total: 0,
Filtered: 0,
Count: 0,
Limit: 0,
Offset: 0,
}
h.sendFormattedResponse(w, tableMetadata, responseMetadata, options)
}
// handleMeta processes meta operation requests
func (h *Handler) handleMeta(ctx context.Context, w common.ResponseWriter, schema, entity string, model interface{}) {
// Capture panics and return error response
defer func() {
if err := recover(); err != nil {
h.handlePanic(w, "handleMeta", err)
}
}()
logger.Info("Getting metadata for %s.%s via meta operation", schema, entity)
metadata := h.generateMetadata(schema, entity, model) metadata := h.generateMetadata(schema, entity, model)
h.sendResponse(w, metadata, nil) h.sendResponse(w, metadata, nil)
} }
@@ -213,6 +285,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
tableName := GetTableName(ctx) tableName := GetTableName(ctx)
model := GetModel(ctx) model := GetModel(ctx)
if id == "" {
options.SingleRecordAsObject = false
}
// Execute BeforeRead hooks // Execute BeforeRead hooks
hookCtx := &HookContext{ hookCtx := &HookContext{
Context: ctx, Context: ctx,
@@ -260,15 +336,23 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
query = query.Table(tableName) query = query.Table(tableName)
} }
// Note: X-Files configuration is now applied via parseXFiles which populates // If we have computed columns/expressions but options.Columns is empty,
// ExtendedRequestOptions fields (columns, filters, sort, preload, etc.) // populate it with all model columns first since computed columns are additions
// These are applied below in the normal query building process if len(options.Columns) == 0 && (len(options.ComputedQL) > 0 || len(options.ComputedColumns) > 0) {
logger.Debug("Populating options.Columns with all model columns since computed columns are additions")
options.Columns = reflection.GetSQLModelColumns(model)
}
// Apply ComputedQL fields if any // Apply ComputedQL fields if any
if len(options.ComputedQL) > 0 { if len(options.ComputedQL) > 0 {
for colName, colExpr := range options.ComputedQL { for colName, colExpr := range options.ComputedQL {
logger.Debug("Applying computed column: %s", colName) logger.Debug("Applying computed column: %s", colName)
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName)) if strings.Contains(colName, "cql") {
query = query.ColumnExpr(fmt.Sprintf("(%s)::text AS %s", colExpr, colName))
} else {
query = query.ColumnExpr(fmt.Sprintf("(%s)AS %s", colExpr, colName))
}
for colIndex := range options.Columns { for colIndex := range options.Columns {
if options.Columns[colIndex] == colName { if options.Columns[colIndex] == colName {
// Remove the computed column from the selected columns to avoid duplication // Remove the computed column from the selected columns to avoid duplication
@@ -282,7 +366,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
if len(options.ComputedColumns) > 0 { if len(options.ComputedColumns) > 0 {
for _, cu := range options.ComputedColumns { for _, cu := range options.ComputedColumns {
logger.Debug("Applying computed column: %s", cu.Name) logger.Debug("Applying computed column: %s", cu.Name)
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name)) if strings.Contains(cu.Name, "cql") {
query = query.ColumnExpr(fmt.Sprintf("(%s)::text AS %s", cu.Expression, cu.Name))
} else {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
for colIndex := range options.Columns { for colIndex := range options.Columns {
if options.Columns[colIndex] == cu.Name { if options.Columns[colIndex] == cu.Name {
// Remove the computed column from the selected columns to avoid duplication // Remove the computed column from the selected columns to avoid duplication
@@ -296,7 +385,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply column selection // Apply column selection
if len(options.Columns) > 0 { if len(options.Columns) > 0 {
logger.Debug("Selecting columns: %v", options.Columns) logger.Debug("Selecting columns: %v", options.Columns)
query = query.Column(options.Columns...) for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
} }
// Apply expand (Just expand to Preload for now) // Apply expand (Just expand to Preload for now)
@@ -344,50 +436,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
for idx := range options.Preload { for idx := range options.Preload {
preload := options.Preload[idx] preload := options.Preload[idx]
logger.Debug("Applying preload: %s", preload.Relation) logger.Debug("Applying preload: %s", preload.Relation)
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
if len(preload.OmitColumns) > 0 {
allCols := reflection.GetModelColumns(model)
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
if len(preload.Columns) > 0 { // Validate and fix WHERE clause to ensure it contains the relation prefix
sq = sq.Column(preload.Columns...) if len(preload.Where) > 0 {
fixedWhere, err := common.ValidateAndFixPreloadWhere(preload.Where, preload.Relation)
if err != nil {
logger.Error("Invalid preload WHERE clause for relation '%s': %v", preload.Relation, err)
h.sendError(w, http.StatusBadRequest, "invalid_preload_where",
fmt.Sprintf("Invalid preload WHERE clause for relation '%s'", preload.Relation), err)
return
} }
preload.Where = fixedWhere
}
if len(preload.Filters) > 0 { // Apply the preload with recursive support
for _, filter := range preload.Filters { query = h.applyPreloadWithRecursion(query, preload, options.Preload, model, 0)
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
if len(preload.Where) > 0 {
sq = sq.Where(preload.Where)
}
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
return sq
})
} }
// Apply DISTINCT if requested // Apply DISTINCT if requested
@@ -417,13 +480,21 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply custom SQL WHERE clause (AND condition) // Apply custom SQL WHERE clause (AND condition)
if options.CustomSQLWhere != "" { if options.CustomSQLWhere != "" {
logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere) logger.Debug("Applying custom SQL WHERE: %s", options.CustomSQLWhere)
query = query.Where(options.CustomSQLWhere) // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedWhere := common.SanitizeWhereClause(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedWhere != "" {
query = query.Where(sanitizedWhere)
}
} }
// Apply custom SQL WHERE clause (OR condition) // Apply custom SQL WHERE clause (OR condition)
if options.CustomSQLOr != "" { if options.CustomSQLOr != "" {
logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr) logger.Debug("Applying custom SQL OR: %s", options.CustomSQLOr)
query = query.WhereOr(options.CustomSQLOr) // Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
sanitizedOr := common.SanitizeWhereClause(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedOr != "" {
query = query.WhereOr(sanitizedOr)
}
} }
// If ID is provided, filter by ID // If ID is provided, filter by ID
@@ -447,14 +518,69 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Get total count before pagination (unless skip count is requested) // Get total count before pagination (unless skip count is requested)
var total int var total int
if !options.SkipCount { if !options.SkipCount {
count, err := query.Count(ctx) // Try to get from cache first (unless SkipCache is true)
if err != nil { var cachedTotal *cache.CachedTotal
logger.Error("Error counting records: %v", err) var cacheKey string
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return if !options.SkipCache {
// Build cache key from query parameters
// Convert expand options to interface slice for the cache key builder
expandOpts := make([]interface{}, len(options.Expand))
for i, exp := range options.Expand {
expandOpts[i] = map[string]interface{}{
"relation": exp.Relation,
"where": exp.Where,
}
}
cacheKeyHash := cache.BuildExtendedQueryCacheKey(
tableName,
options.Filters,
options.Sort,
options.CustomSQLWhere,
options.CustomSQLOr,
expandOpts,
options.Distinct,
options.CursorForward,
options.CursorBackward,
)
cacheKey = cache.GetQueryTotalCacheKey(cacheKeyHash)
// Try to retrieve from cache
cachedTotal = &cache.CachedTotal{}
err := cache.GetDefaultCache().Get(ctx, cacheKey, cachedTotal)
if err == nil {
total = cachedTotal.Total
logger.Debug("Total records (from cache): %d", total)
} else {
logger.Debug("Cache miss for query total")
cachedTotal = nil
}
}
// If not in cache or cache skip, execute count query
if cachedTotal == nil {
count, err := query.Count(ctx)
if err != nil {
logger.Error("Error counting records: %v", err)
h.sendError(w, http.StatusInternalServerError, "query_error", "Error counting records", err)
return
}
total = count
logger.Debug("Total records (from query): %d", total)
// Store in cache (if caching is enabled)
if !options.SkipCache && cacheKey != "" {
cacheTTL := time.Minute * 2 // Default 2 minutes TTL
cacheData := &cache.CachedTotal{Total: total}
if err := cache.GetDefaultCache().Set(ctx, cacheKey, cacheData, cacheTTL); err != nil {
logger.Warn("Failed to cache query total: %v", err)
// Don't fail the request if caching fails
} else {
logger.Debug("Cached query total with key: %s", cacheKey)
}
}
} }
total = count
logger.Debug("Total records: %d", total)
} else { } else {
logger.Debug("Skipping count as requested") logger.Debug("Skipping count as requested")
total = -1 // Indicate count was skipped total = -1 // Indicate count was skipped
@@ -499,7 +625,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Apply cursor filter to query // Apply cursor filter to query
if cursorFilter != "" { if cursorFilter != "" {
logger.Debug("Applying cursor filter: %s", cursorFilter) logger.Debug("Applying cursor filter: %s", cursorFilter)
query = query.Where(cursorFilter) sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
if sanitizedCursor != "" {
query = query.Where(sanitizedCursor)
}
} }
} }
@@ -573,6 +702,144 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
h.sendFormattedResponse(w, modelPtr, metadata, options) h.sendFormattedResponse(w, modelPtr, metadata, options)
} }
// applyPreloadWithRecursion applies a preload with support for ComputedQL and recursive preloading
func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload common.PreloadOption, allPreloads []common.PreloadOption, model interface{}, depth int) common.SelectQuery {
// Log relationship keys if they're specified (from XFiles)
if preload.RelatedKey != "" || preload.ForeignKey != "" || preload.PrimaryKey != "" {
logger.Debug("Preload %s has relationship keys - PK: %s, RelatedKey: %s, ForeignKey: %s",
preload.Relation, preload.PrimaryKey, preload.RelatedKey, preload.ForeignKey)
// Build a WHERE clause using the relationship keys if needed
// Note: Bun's PreloadRelation typically handles the relationship join automatically via struct tags
// However, if the relationship keys are explicitly provided from XFiles, we can use them
// to add additional filtering or validation
if preload.RelatedKey != "" && preload.Where == "" {
// For child tables: ensure the child's relatedkey column will be matched
// The actual parent value is dynamic and handled by Bun's preload mechanism
// We just log this for visibility
logger.Debug("Child table %s will be filtered by %s matching parent's primary key",
preload.Relation, preload.RelatedKey)
}
if preload.ForeignKey != "" && preload.Where == "" {
// For parent tables: ensure the parent's primary key matches the current table's foreign key
logger.Debug("Parent table %s will be filtered by primary key matching current table's %s",
preload.Relation, preload.ForeignKey)
}
}
// Apply the preload
query = query.PreloadRelation(preload.Relation, func(sq common.SelectQuery) common.SelectQuery {
// Get the related model for column operations
relatedModel := reflection.GetRelationModel(model, preload.Relation)
if relatedModel == nil {
logger.Warn("Could not get related model for preload: %s", preload.Relation)
// relatedModel = model // fallback to parent model
} else {
// If we have computed columns but no explicit columns, populate with all model columns first
// since computed columns are additions
if len(preload.Columns) == 0 && (len(preload.ComputedQL) > 0 || len(preload.OmitColumns) > 0) {
logger.Debug("Populating preload columns with all model columns since computed columns are additions")
preload.Columns = reflection.GetSQLModelColumns(relatedModel)
}
// Apply ComputedQL fields if any
if len(preload.ComputedQL) > 0 {
for colName, colExpr := range preload.ComputedQL {
logger.Debug("Applying computed column to preload %s: %s", preload.Relation, colName)
sq = sq.ColumnExpr(fmt.Sprintf("(%s) AS %s", colExpr, colName))
// Remove the computed column from selected columns to avoid duplication
for colIndex := range preload.Columns {
if preload.Columns[colIndex] == colName {
preload.Columns = append(preload.Columns[:colIndex], preload.Columns[colIndex+1:]...)
break
}
}
}
}
// Handle OmitColumns
if len(preload.OmitColumns) > 0 {
allCols := preload.Columns
// Remove omitted columns
preload.Columns = []string{}
for _, col := range allCols {
addCols := true
for _, omitCol := range preload.OmitColumns {
if col == omitCol {
addCols = false
break
}
}
if addCols {
preload.Columns = append(preload.Columns, col)
}
}
}
// Apply column selection
if len(preload.Columns) > 0 {
sq = sq.Column(preload.Columns...)
}
}
// Apply filters
if len(preload.Filters) > 0 {
for _, filter := range preload.Filters {
sq = h.applyFilter(sq, filter, "", false, "AND")
}
}
// Apply sorting
if len(preload.Sort) > 0 {
for _, sort := range preload.Sort {
sq = sq.Order(fmt.Sprintf("%s %s", sort.Column, sort.Direction))
}
}
// Apply WHERE clause
if len(preload.Where) > 0 {
// Build RequestOptions with all preloads to allow references to sibling relations
preloadOpts := &common.RequestOptions{Preload: allPreloads}
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
if len(sanitizedWhere) > 0 {
sq = sq.Where(sanitizedWhere)
}
}
// Apply limit
if preload.Limit != nil && *preload.Limit > 0 {
sq = sq.Limit(*preload.Limit)
}
if preload.Offset != nil && *preload.Offset > 0 {
sq = sq.Offset(*preload.Offset)
}
return sq
})
// Handle recursive preloading
if preload.Recursive && depth < 5 {
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
// For recursive relationships, we need to get the last part of the relation path
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
relationParts := strings.Split(preload.Relation, ".")
lastRelationName := relationParts[len(relationParts)-1]
// Create a recursive preload with the same configuration
// but with the relation path extended
recursivePreload := preload
recursivePreload.Relation = preload.Relation + "." + lastRelationName
// Recursively apply preload until we reach depth 5
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
}
return query
}
func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) { func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, data interface{}, options ExtendedRequestOptions) {
// Capture panics and return error response // Capture panics and return error response
defer func() { defer func() {
@@ -1233,7 +1500,7 @@ func (h *Handler) normalizeToSlice(data interface{}) []interface{} {
func (h *Handler) extractNestedRelations( func (h *Handler) extractNestedRelations(
data map[string]interface{}, data map[string]interface{},
model interface{}, model interface{},
) (map[string]interface{}, map[string]interface{}, error) { ) (_cleanedData map[string]interface{}, _relations map[string]interface{}, _err error) {
// Get model type for reflection // Get model type for reflection
modelType := reflect.TypeOf(model) modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) { for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
@@ -1565,23 +1832,52 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
if modelType.Kind() != reflect.Struct { if modelType.Kind() != reflect.Struct {
logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity) logger.Error("Model type must be a struct, got %s for %s.%s", modelType.Kind(), schema, entity)
return &common.TableMetadata{ return &common.TableMetadata{
Schema: schema, Schema: schema,
Table: h.getTableName(schema, entity, model), Table: h.getTableName(schema, entity, model),
Columns: []common.Column{}, Columns: []common.Column{},
Relations: []string{},
} }
} }
tableName := h.getTableName(schema, entity, model) tableName := h.getTableName(schema, entity, model)
metadata := &common.TableMetadata{ metadata := &common.TableMetadata{
Schema: schema, Schema: schema,
Table: tableName, Table: tableName,
Columns: []common.Column{}, Columns: []common.Column{},
Relations: []string{},
} }
for i := 0; i < modelType.NumField(); i++ { for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i) field := modelType.Field(i)
// Skip unexported fields
if !field.IsExported() {
continue
}
gormTag := field.Tag.Get("gorm")
jsonTag := field.Tag.Get("json")
// Skip fields with json:"-"
if jsonTag == "-" {
continue
}
// Get JSON name
jsonName := strings.Split(jsonTag, ",")[0]
if jsonName == "" {
jsonName = field.Name
}
// Check if this is a relation field (slice or struct, but not time.Time)
if field.Type.Kind() == reflect.Slice ||
(field.Type.Kind() == reflect.Struct && field.Type.Name() != "Time") ||
(field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct && field.Type.Elem().Name() != "Time") {
metadata.Relations = append(metadata.Relations, jsonName)
continue
}
// Get column name from gorm tag or json tag // Get column name from gorm tag or json tag
columnName := field.Tag.Get("gorm") columnName := field.Tag.Get("gorm")
if strings.Contains(columnName, "column:") { if strings.Contains(columnName, "column:") {
@@ -1593,15 +1889,9 @@ func (h *Handler) generateMetadata(schema, entity string, model interface{}) *co
} }
} }
} else { } else {
columnName = field.Tag.Get("json") columnName = jsonName
if columnName == "" || columnName == "-" {
columnName = strings.ToLower(field.Name)
}
} }
// Check for primary key and unique constraint
gormTag := field.Tag.Get("gorm")
column := common.Column{ column := common.Column{
Name: columnName, Name: columnName,
Type: h.getColumnType(field.Type), Type: h.getColumnType(field.Type),
@@ -1652,6 +1942,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
} }
// Return data as-is without wrapping in common.Response // Return data as-is without wrapping in common.Response
w.SetHeader("Content-Type", "application/json")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
if err := w.WriteJSON(data); err != nil { if err := w.WriteJSON(data); err != nil {
logger.Error("Failed to write JSON response: %v", err) logger.Error("Failed to write JSON response: %v", err)
@@ -1662,7 +1953,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged // Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
func (h *Handler) normalizeResultArray(data interface{}) interface{} { func (h *Handler) normalizeResultArray(data interface{}) interface{} {
if data == nil { if data == nil {
return data return nil
} }
// Use reflection to check if data is a slice or array // Use reflection to check if data is a slice or array
@@ -1754,6 +2045,7 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
"_error": errorMsg, "_error": errorMsg,
"_retval": 1, "_retval": 1,
} }
w.SetHeader("Content-Type", "application/json")
w.WriteHeader(statusCode) w.WriteHeader(statusCode)
if jsonErr := w.WriteJSON(response); jsonErr != nil { if jsonErr := w.WriteJSON(response); jsonErr != nil {
logger.Error("Failed to write JSON error response: %v", jsonErr) logger.Error("Failed to write JSON error response: %v", jsonErr)
@@ -2102,3 +2394,35 @@ func (h *Handler) extractTagValue(tag, key string) string {
} }
return "" return ""
} }
// HandleOpenAPI generates and returns the OpenAPI specification
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
// Import needed here to avoid circular dependency
// The import is done inline
// We'll use a factory function approach instead
if h.openAPIGenerator == nil {
logger.Error("OpenAPI generator not configured")
h.sendError(w, http.StatusInternalServerError, "openapi_not_configured", "OpenAPI generation not configured", nil)
return
}
spec, err := h.openAPIGenerator()
if err != nil {
logger.Error("Failed to generate OpenAPI spec: %v", err)
h.sendError(w, http.StatusInternalServerError, "openapi_generation_error", "Failed to generate OpenAPI specification", err)
return
}
w.SetHeader("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, err = w.Write([]byte(spec))
if err != nil {
logger.Error("Error sending OpenAPI spec response: %v", err)
}
}
// SetOpenAPIGenerator sets the OpenAPI generator function
// This allows avoiding circular dependencies
func (h *Handler) SetOpenAPIGenerator(generator func() (string, error)) {
h.openAPIGenerator = generator
}

Some files were not shown because too many files have changed in this diff Show More