mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2025-12-29 07:44:25 +00:00
Compare commits
108 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e2fe33b77 | ||
|
|
1baa0af0ac | ||
|
|
659b2925e4 | ||
|
|
baca70cafc | ||
|
|
ed57978620 | ||
|
|
97b39de88a | ||
|
|
bf955b7971 | ||
|
|
545856f8a0 | ||
|
|
8d123e47bd | ||
|
|
c9eaf84125 | ||
|
|
aeae9d7e0c | ||
|
|
2a84652dba | ||
|
|
b741958895 | ||
|
|
2442589982 | ||
|
|
7c1bae60c9 | ||
|
|
06b2404c0c | ||
|
|
32007480c6 | ||
|
|
ab1ce869b6 | ||
|
|
ff72e04428 | ||
|
|
e35f8a4f14 | ||
|
|
5ff9a8a24e | ||
|
|
81b87af6e4 | ||
|
|
f3ba314640 | ||
|
|
93df33e274 | ||
|
|
abd045493a | ||
|
|
a61556d857 | ||
|
|
eaf1133575 | ||
|
|
8172c0495d | ||
|
|
7a3c368121 | ||
|
|
9c5c7689e9 | ||
|
|
08050c960d | ||
|
|
78029fb34f | ||
|
|
1643a5e920 | ||
|
|
6bbe0ec8b0 | ||
|
|
e32ec9e17e | ||
|
|
26c175e65e | ||
|
|
aa99e8e4bc | ||
|
|
163593901f | ||
|
|
1261960e97 | ||
|
|
76bbf33db2 | ||
|
|
02c9b96b0c | ||
|
|
9a3564f05f | ||
|
|
a931b8cdd2 | ||
|
|
7e76977dcc | ||
|
|
7853a3f56a | ||
|
|
c2e0c36c79 | ||
|
|
59bd709460 | ||
|
|
05962035b6 | ||
|
|
1cd04b7083 | ||
|
|
0d4909054c | ||
|
|
745564f2e7 | ||
|
|
311e50bfdd | ||
|
|
c95bc9e633 | ||
|
|
07b09e2025 | ||
|
|
3d5334002d | ||
|
|
640582d508 | ||
|
|
b0b3ae662b | ||
|
|
c9b9f75b06 | ||
|
|
af3260864d | ||
|
|
ca6d2deff6 | ||
|
|
1481443516 | ||
|
|
cb54ec5e27 | ||
|
|
7d6a9025f5 | ||
|
|
35089f511f | ||
|
|
66b6a0d835 | ||
|
|
456c165814 | ||
|
|
850d7b546c | ||
|
|
a44ef90d7c | ||
|
|
8b7db5b31a | ||
|
|
14daea3b05 | ||
|
|
35f23b6d9e | ||
|
|
53a4e67f70 | ||
|
|
1289c3af88 | ||
|
|
cdfb7a67fd | ||
|
|
7f5b851669 | ||
|
|
f0e26b1c0d | ||
|
|
1db1b924ef | ||
|
|
d9cf23b1dc | ||
|
|
94f013c872 | ||
|
|
c52fcff61d | ||
|
|
ce106fa940 | ||
|
|
37b4b75175 | ||
|
|
0cef0f75d3 | ||
|
|
006dc4a2b2 | ||
|
|
ecd7b31910 | ||
|
|
7b8216b71c | ||
|
|
682716dd31 | ||
|
|
412bbab560 | ||
|
|
dc3254522c | ||
|
|
2818e7e9cd | ||
|
|
e39012ddbd | ||
|
|
ceaa251301 | ||
|
|
faafe5abea | ||
|
|
3eb17666bf | ||
|
|
c8704c07dd | ||
|
|
fc82a9bc50 | ||
|
|
c26ea3cd61 | ||
|
|
a5d97cc07b | ||
|
|
0899ba5029 | ||
|
|
c84dd7dc91 | ||
|
|
f1c6b36374 | ||
|
|
abee5c942f | ||
|
|
2e9a0bd51a | ||
|
|
f518a3c73c | ||
|
|
07c239aaa1 | ||
|
|
1adca4c49b | ||
|
|
eefed23766 | ||
|
|
3b2d05465e |
1
.claude/readme
Normal file
1
.claude/readme
Normal file
@@ -0,0 +1 @@
|
||||
We use claude for testing and document generation.
|
||||
52
.env.example
Normal file
52
.env.example
Normal file
@@ -0,0 +1,52 @@
|
||||
# ResolveSpec Environment Variables Example
|
||||
# Environment variables override config file settings
|
||||
# All variables are prefixed with RESOLVESPEC_
|
||||
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
|
||||
|
||||
# Server Configuration
|
||||
RESOLVESPEC_SERVER_ADDR=:8080
|
||||
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
|
||||
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
|
||||
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
|
||||
|
||||
# Tracing Configuration
|
||||
RESOLVESPEC_TRACING_ENABLED=false
|
||||
RESOLVESPEC_TRACING_SERVICE_NAME=resolvespec
|
||||
RESOLVESPEC_TRACING_SERVICE_VERSION=1.0.0
|
||||
RESOLVESPEC_TRACING_ENDPOINT=http://localhost:4318/v1/traces
|
||||
|
||||
# Cache Configuration
|
||||
RESOLVESPEC_CACHE_PROVIDER=memory
|
||||
|
||||
# Redis Cache (when provider=redis)
|
||||
RESOLVESPEC_CACHE_REDIS_HOST=localhost
|
||||
RESOLVESPEC_CACHE_REDIS_PORT=6379
|
||||
RESOLVESPEC_CACHE_REDIS_PASSWORD=
|
||||
RESOLVESPEC_CACHE_REDIS_DB=0
|
||||
|
||||
# Memcache (when provider=memcache)
|
||||
# Note: For arrays, separate values with commas
|
||||
RESOLVESPEC_CACHE_MEMCACHE_SERVERS=localhost:11211
|
||||
RESOLVESPEC_CACHE_MEMCACHE_MAX_IDLE_CONNS=10
|
||||
RESOLVESPEC_CACHE_MEMCACHE_TIMEOUT=100ms
|
||||
|
||||
# Logger Configuration
|
||||
RESOLVESPEC_LOGGER_DEV=false
|
||||
RESOLVESPEC_LOGGER_PATH=
|
||||
|
||||
# Middleware Configuration
|
||||
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_RPS=100.0
|
||||
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_BURST=200
|
||||
RESOLVESPEC_MIDDLEWARE_MAX_REQUEST_SIZE=10485760
|
||||
|
||||
# CORS Configuration
|
||||
# Note: For arrays in env vars, separate with commas
|
||||
RESOLVESPEC_CORS_ALLOWED_ORIGINS=*
|
||||
RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
|
||||
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
|
||||
RESOLVESPEC_CORS_MAX_AGE=3600
|
||||
|
||||
# Database Configuration
|
||||
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
|
||||
84
.github/workflows/maint.yml
vendored
Normal file
84
.github/workflows/maint.yml
vendored
Normal file
@@ -0,0 +1,84 @@
|
||||
name: Build , Vet Test, and Lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run Vet Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.23.x", "1.24.x"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ${{ matrix.go-version }}
|
||||
cache: true
|
||||
|
||||
- name: Display Go version
|
||||
run: go version
|
||||
|
||||
- name: Download dependencies
|
||||
run: go mod download
|
||||
|
||||
- name: Verify dependencies
|
||||
run: go mod verify
|
||||
|
||||
- name: Run go vet
|
||||
run: go vet ./...
|
||||
|
||||
lint:
|
||||
name: Lint Code
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: true
|
||||
|
||||
- name: Run golangci-lint
|
||||
uses: golangci/golangci-lint-action@v9
|
||||
with:
|
||||
version: latest
|
||||
args: --timeout=5m
|
||||
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.x"
|
||||
cache: true
|
||||
|
||||
- name: Build
|
||||
run: go build -v ./...
|
||||
|
||||
- name: Check for uncommitted changes
|
||||
run: |
|
||||
if [[ -n $(git status -s) ]]; then
|
||||
echo "Error: Uncommitted changes found after build"
|
||||
git status -s
|
||||
exit 1
|
||||
fi
|
||||
81
.github/workflows/tests.yml
vendored
Normal file
81
.github/workflows/tests.yml
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
name: Tests
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
jobs:
|
||||
unit-tests:
|
||||
name: Unit Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: "1.24"
|
||||
- name: Run unit tests
|
||||
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
|
||||
- name: Generate coverage report
|
||||
run: |
|
||||
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
|
||||
go tool cover -html=coverage.out -o coverage.html
|
||||
- name: Upload coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: coverage-report
|
||||
path: coverage.html
|
||||
integration-tests:
|
||||
name: Integration Tests
|
||||
runs-on: ubuntu-latest
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:15-alpine
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v6
|
||||
with:
|
||||
go-version: "1.24"
|
||||
- name: Create test databases
|
||||
env:
|
||||
PGPASSWORD: postgres
|
||||
run: |
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
|
||||
- name: Run resolvespec integration tests
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
|
||||
- name: Run restheadspec integration tests
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
|
||||
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
|
||||
- name: Generate integration coverage
|
||||
env:
|
||||
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||
run: |
|
||||
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
|
||||
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
|
||||
- name: Upload resolvespec integration coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: resolvespec-integration-coverage-report
|
||||
path: coverage-resolvespec-integration.html
|
||||
|
||||
- name: Upload restheadspec integration coverage
|
||||
uses: actions/upload-artifact@v5
|
||||
with:
|
||||
name: integration-coverage-restheadspec-report
|
||||
path: coverage-restheadspec-integration
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -23,4 +23,5 @@ go.work.sum
|
||||
|
||||
# env file
|
||||
.env
|
||||
bin/
|
||||
bin/
|
||||
test.db
|
||||
|
||||
110
.golangci.bck.yml
Normal file
110
.golangci.bck.yml
Normal file
@@ -0,0 +1,110 @@
|
||||
run:
|
||||
timeout: 5m
|
||||
tests: true
|
||||
skip-dirs:
|
||||
- vendor
|
||||
- .github
|
||||
|
||||
linters:
|
||||
enable:
|
||||
- errcheck
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- staticcheck
|
||||
- unused
|
||||
- gofmt
|
||||
- goimports
|
||||
- misspell
|
||||
- gocritic
|
||||
- revive
|
||||
- stylecheck
|
||||
disable:
|
||||
- typecheck # Can cause issues with generics in some cases
|
||||
|
||||
linters-settings:
|
||||
errcheck:
|
||||
check-type-assertions: false
|
||||
check-blank: false
|
||||
|
||||
govet:
|
||||
check-shadowing: false
|
||||
|
||||
gofmt:
|
||||
simplify: true
|
||||
|
||||
goimports:
|
||||
local-prefixes: github.com/bitechdev/ResolveSpec
|
||||
|
||||
gocritic:
|
||||
enabled-checks:
|
||||
- appendAssign
|
||||
- assignOp
|
||||
- boolExprSimplify
|
||||
- builtinShadow
|
||||
- captLocal
|
||||
- caseOrder
|
||||
- defaultCaseOrder
|
||||
- dupArg
|
||||
- dupBranchBody
|
||||
- dupCase
|
||||
- dupSubExpr
|
||||
- elseif
|
||||
- emptyFallthrough
|
||||
- equalFold
|
||||
- flagName
|
||||
- ifElseChain
|
||||
- indexAlloc
|
||||
- initClause
|
||||
- methodExprCall
|
||||
- nilValReturn
|
||||
- rangeExprCopy
|
||||
- rangeValCopy
|
||||
- regexpMust
|
||||
- singleCaseSwitch
|
||||
- sloppyLen
|
||||
- stringXbytes
|
||||
- switchTrue
|
||||
- typeAssertChain
|
||||
- typeSwitchVar
|
||||
- underef
|
||||
- unlabelStmt
|
||||
- unnamedResult
|
||||
- unnecessaryBlock
|
||||
- weakCond
|
||||
- yodaStyleExpr
|
||||
|
||||
revive:
|
||||
rules:
|
||||
- name: exported
|
||||
disabled: true
|
||||
- name: package-comments
|
||||
disabled: true
|
||||
|
||||
issues:
|
||||
exclude-use-default: false
|
||||
max-issues-per-linter: 0
|
||||
max-same-issues: 0
|
||||
|
||||
# Exclude some linters from running on tests files
|
||||
exclude-rules:
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- errcheck
|
||||
- dupl
|
||||
- gosec
|
||||
- gocritic
|
||||
|
||||
# Ignore "error return value not checked" for defer statements
|
||||
- linters:
|
||||
- errcheck
|
||||
text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
|
||||
|
||||
# Ignore complexity in test files
|
||||
- path: _test\.go
|
||||
text: "cognitive complexity|cyclomatic complexity"
|
||||
|
||||
output:
|
||||
format: colored-line-number
|
||||
print-issued-lines: true
|
||||
print-linter-name: true
|
||||
131
.golangci.json
Normal file
131
.golangci.json
Normal file
@@ -0,0 +1,131 @@
|
||||
{
|
||||
"formatters": {
|
||||
"enable": [
|
||||
"gofmt",
|
||||
"goimports"
|
||||
],
|
||||
"exclusions": {
|
||||
"generated": "lax",
|
||||
"paths": [
|
||||
"third_party$",
|
||||
"builtin$",
|
||||
"examples$"
|
||||
]
|
||||
},
|
||||
"settings": {
|
||||
"gofmt": {
|
||||
"simplify": true
|
||||
},
|
||||
"goimports": {
|
||||
"local-prefixes": [
|
||||
"github.com/bitechdev/ResolveSpec"
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"issues": {
|
||||
"max-issues-per-linter": 0,
|
||||
"max-same-issues": 0
|
||||
},
|
||||
"linters": {
|
||||
"enable": [
|
||||
"gocritic",
|
||||
"misspell",
|
||||
"revive"
|
||||
],
|
||||
"exclusions": {
|
||||
"generated": "lax",
|
||||
"paths": [
|
||||
"third_party$",
|
||||
"builtin$",
|
||||
"examples$",
|
||||
"mocks?",
|
||||
"tests?"
|
||||
],
|
||||
"rules": [
|
||||
{
|
||||
"linters": [
|
||||
"dupl",
|
||||
"errcheck",
|
||||
"gocritic",
|
||||
"gosec"
|
||||
],
|
||||
"path": "_test\\.go"
|
||||
},
|
||||
{
|
||||
"linters": [
|
||||
"errcheck"
|
||||
],
|
||||
"text": "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
|
||||
},
|
||||
{
|
||||
"path": "_test\\.go",
|
||||
"text": "cognitive complexity|cyclomatic complexity"
|
||||
}
|
||||
]
|
||||
},
|
||||
"settings": {
|
||||
"errcheck": {
|
||||
"check-blank": false,
|
||||
"check-type-assertions": false
|
||||
},
|
||||
"gocritic": {
|
||||
"enabled-checks": [
|
||||
"appendAssign",
|
||||
"assignOp",
|
||||
"boolExprSimplify",
|
||||
"builtinShadow",
|
||||
"captLocal",
|
||||
"caseOrder",
|
||||
"defaultCaseOrder",
|
||||
"dupArg",
|
||||
"dupBranchBody",
|
||||
"dupCase",
|
||||
"dupSubExpr",
|
||||
"elseif",
|
||||
"emptyFallthrough",
|
||||
"equalFold",
|
||||
"flagName",
|
||||
"indexAlloc",
|
||||
"initClause",
|
||||
"methodExprCall",
|
||||
"nilValReturn",
|
||||
"rangeExprCopy",
|
||||
"rangeValCopy",
|
||||
"regexpMust",
|
||||
"singleCaseSwitch",
|
||||
"sloppyLen",
|
||||
"stringXbytes",
|
||||
"switchTrue",
|
||||
"typeAssertChain",
|
||||
"typeSwitchVar",
|
||||
"underef",
|
||||
"unlabelStmt",
|
||||
"unnamedResult",
|
||||
"unnecessaryBlock",
|
||||
"weakCond",
|
||||
"yodaStyleExpr"
|
||||
],
|
||||
"disabled-checks": [
|
||||
"ifElseChain"
|
||||
]
|
||||
},
|
||||
"revive": {
|
||||
"rules": [
|
||||
{
|
||||
"disabled": true,
|
||||
"name": "exported"
|
||||
},
|
||||
{
|
||||
"disabled": true,
|
||||
"name": "package-comments"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"run": {
|
||||
"tests": true
|
||||
},
|
||||
"version": "2"
|
||||
}
|
||||
56
.vscode/settings.json
vendored
Normal file
56
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,56 @@
|
||||
{
|
||||
"go.testFlags": [
|
||||
"-v"
|
||||
],
|
||||
"go.testTimeout": "300s",
|
||||
"go.coverOnSave": false,
|
||||
"go.coverOnSingleTest": true,
|
||||
"go.coverageDecorator": {
|
||||
"type": "gutter"
|
||||
},
|
||||
"go.testEnvVars": {
|
||||
"TEST_DATABASE_URL": "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
|
||||
},
|
||||
"go.toolsEnvVars": {
|
||||
"CGO_ENABLED": "0"
|
||||
},
|
||||
"go.buildTags": "",
|
||||
"go.testTags": "",
|
||||
"files.exclude": {
|
||||
"**/.git": true,
|
||||
"**/.DS_Store": true,
|
||||
"**/coverage.out": true,
|
||||
"**/coverage.html": true,
|
||||
"**/coverage-integration.out": true,
|
||||
"**/coverage-integration.html": true
|
||||
},
|
||||
"files.watcherExclude": {
|
||||
"**/.git/objects/**": true,
|
||||
"**/.git/subtree-cache/**": true,
|
||||
"**/node_modules/*/**": true,
|
||||
"**/.hg/store/**": true,
|
||||
"**/vendor/**": true
|
||||
},
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.organizeImports": "explicit"
|
||||
},
|
||||
"[go]": {
|
||||
"editor.defaultFormatter": "golang.go",
|
||||
"editor.formatOnSave": true,
|
||||
"editor.insertSpaces": false,
|
||||
"editor.tabSize": 4
|
||||
},
|
||||
"gopls": {
|
||||
"ui.completion.usePlaceholders": true,
|
||||
"ui.semanticTokens": true,
|
||||
"ui.codelenses": {
|
||||
"generate": true,
|
||||
"regenerate_cgo": true,
|
||||
"test": true,
|
||||
"tidy": true,
|
||||
"upgrade_dependency": true,
|
||||
"vendor": true
|
||||
}
|
||||
}
|
||||
}
|
||||
262
.vscode/tasks.json
vendored
262
.vscode/tasks.json
vendored
@@ -6,10 +6,10 @@
|
||||
"label": "go: build workspace",
|
||||
"command": "build",
|
||||
"options": {
|
||||
"env": {
|
||||
"env": {
|
||||
"CGO_ENABLED": "0"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"cwd": "${workspaceFolder}/bin"
|
||||
},
|
||||
"args": [
|
||||
"../..."
|
||||
@@ -17,28 +17,262 @@
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "build",
|
||||
|
||||
"group": "build"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: unit tests (all)",
|
||||
"command": "go test ./pkg/resolvespec ./pkg/restheadspec -v -cover",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": {
|
||||
"kind": "test",
|
||||
"isDefault": true
|
||||
},
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared",
|
||||
"focus": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: unit tests (resolvespec)",
|
||||
"command": "go test ./pkg/resolvespec -v -cover",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: unit tests (restheadspec)",
|
||||
"command": "go test ./pkg/restheadspec -v -cover",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: integration tests (automated)",
|
||||
"command": "./scripts/run-integration-tests.sh",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "dedicated",
|
||||
"focus": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: integration tests (resolvespec only)",
|
||||
"command": "./scripts/run-integration-tests.sh resolvespec",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "dedicated"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: integration tests (restheadspec only)",
|
||||
"command": "./scripts/run-integration-tests.sh restheadspec",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "dedicated"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: coverage report",
|
||||
"command": "make coverage",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: integration coverage report",
|
||||
"command": "make coverage-integration",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "docker: start postgres",
|
||||
"command": "make docker-up",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "docker: stop postgres",
|
||||
"command": "make docker-down",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "docker: clean postgres data",
|
||||
"command": "make clean",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "go",
|
||||
"label": "go: test workspace",
|
||||
"label": "go: test workspace (with race)",
|
||||
"command": "test",
|
||||
|
||||
"options": {
|
||||
"env": {
|
||||
"CGO_ENABLED": "0"
|
||||
},
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"args": [
|
||||
"../..."
|
||||
"-v",
|
||||
"-race",
|
||||
"-coverprofile=coverage.out",
|
||||
"-covermode=atomic",
|
||||
"./..."
|
||||
],
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "build",
|
||||
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "shared"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: vet workspace",
|
||||
"command": "go vet ./...",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "go: lint workspace",
|
||||
"command": "golangci-lint run --timeout=5m",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"problemMatcher": [],
|
||||
"group": "test"
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: all tests (unit + integration)",
|
||||
"command": "make test",
|
||||
"options": {
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
"dependsOn": [
|
||||
"docker: start postgres"
|
||||
],
|
||||
"problemMatcher": [
|
||||
"$go"
|
||||
],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "dedicated",
|
||||
"focus": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "test: full suite with checks",
|
||||
"dependsOrder": "sequence",
|
||||
"dependsOn": [
|
||||
"go: vet workspace",
|
||||
"test: unit tests (all)",
|
||||
"test: integration tests (automated)"
|
||||
],
|
||||
"problemMatcher": [],
|
||||
"group": "test",
|
||||
"presentation": {
|
||||
"reveal": "always",
|
||||
"panel": "dedicated"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "shell",
|
||||
"label": "Make Release",
|
||||
"problemMatcher": [],
|
||||
"command": "sh ${workspaceFolder}/make_release.sh"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Warky Devs Pty Ltd
|
||||
Copyright (c) 2025
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
@@ -1,173 +0,0 @@
|
||||
# Migration Guide: Database and Router Abstraction
|
||||
|
||||
This guide explains how to migrate from the direct GORM/Router dependencies to the new abstracted interfaces.
|
||||
|
||||
## Overview of Changes
|
||||
|
||||
### What was changed:
|
||||
1. **Database Operations**: GORM-specific code is now abstracted behind `Database` interface
|
||||
2. **Router Integration**: HTTP router dependencies are abstracted behind `Router` interface
|
||||
3. **Model Registry**: Models are now managed through a `ModelRegistry` interface
|
||||
4. **Backward Compatibility**: Existing code continues to work with `NewAPIHandler()`
|
||||
|
||||
### Benefits:
|
||||
- **Database Flexibility**: Switch between GORM, Bun, or other ORMs without code changes
|
||||
- **Router Flexibility**: Use Gorilla Mux, Gin, Echo, or other routers
|
||||
- **Better Testing**: Easy to mock database and router interactions
|
||||
- **Cleaner Separation**: Business logic separated from ORM/router specifics
|
||||
|
||||
## Migration Path
|
||||
|
||||
### Option 1: No Changes Required (Backward Compatible)
|
||||
Your existing code continues to work without any changes:
|
||||
|
||||
```go
|
||||
// This still works exactly as before
|
||||
handler := resolvespec.NewAPIHandler(db)
|
||||
```
|
||||
|
||||
### Option 2: Gradual Migration to New API
|
||||
|
||||
#### Step 1: Use New Handler Constructor
|
||||
```go
|
||||
// Old way
|
||||
handler := resolvespec.NewAPIHandler(gormDB)
|
||||
|
||||
// New way
|
||||
handler := resolvespec.NewHandlerWithGORM(gormDB)
|
||||
```
|
||||
|
||||
#### Step 2: Use Interface-based Approach
|
||||
```go
|
||||
// Create database adapter
|
||||
dbAdapter := resolvespec.NewGormAdapter(gormDB)
|
||||
|
||||
// Create model registry
|
||||
registry := resolvespec.NewModelRegistry()
|
||||
|
||||
// Register your models
|
||||
registry.RegisterModel("public.users", &User{})
|
||||
registry.RegisterModel("public.orders", &Order{})
|
||||
|
||||
// Create handler
|
||||
handler := resolvespec.NewHandler(dbAdapter, registry)
|
||||
```
|
||||
|
||||
## Switching Database Backends
|
||||
|
||||
### From GORM to Bun
|
||||
```go
|
||||
// Add bun dependency first:
|
||||
// go get github.com/uptrace/bun
|
||||
|
||||
// Old GORM setup
|
||||
gormDB, _ := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
|
||||
gormAdapter := resolvespec.NewGormAdapter(gormDB)
|
||||
|
||||
// New Bun setup
|
||||
sqlDB, _ := sql.Open("sqlite3", "test.db")
|
||||
bunDB := bun.NewDB(sqlDB, sqlitedialect.New())
|
||||
bunAdapter := resolvespec.NewBunAdapter(bunDB)
|
||||
|
||||
// Handler creation is identical
|
||||
handler := resolvespec.NewHandler(bunAdapter, registry)
|
||||
```
|
||||
|
||||
## Router Flexibility
|
||||
|
||||
### Current Gorilla Mux (Default)
|
||||
```go
|
||||
router := mux.NewRouter()
|
||||
resolvespec.SetupRoutes(router, handler)
|
||||
```
|
||||
|
||||
### BunRouter (Built-in Support)
|
||||
```go
|
||||
// Simple setup
|
||||
router := bunrouter.New()
|
||||
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
|
||||
|
||||
// Or using adapter
|
||||
routerAdapter := resolvespec.NewStandardBunRouterAdapter()
|
||||
// Use routerAdapter.GetBunRouter() for the underlying router
|
||||
```
|
||||
|
||||
### Using Router Adapters (Advanced)
|
||||
```go
|
||||
// For when you want router abstraction
|
||||
routerAdapter := resolvespec.NewStandardRouter()
|
||||
routerAdapter.RegisterRoute("/{schema}/{entity}", handlerFunc)
|
||||
```
|
||||
|
||||
## Model Registration
|
||||
|
||||
### Old Way (Still Works)
|
||||
```go
|
||||
// Models registered through existing models package
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
```
|
||||
|
||||
### New Way (Recommended)
|
||||
```go
|
||||
registry := resolvespec.NewModelRegistry()
|
||||
registry.RegisterModel("public.users", &User{})
|
||||
registry.RegisterModel("public.orders", &Order{})
|
||||
|
||||
handler := resolvespec.NewHandler(dbAdapter, registry)
|
||||
```
|
||||
|
||||
## Interface Definitions
|
||||
|
||||
### Database Interface
|
||||
```go
|
||||
type Database interface {
|
||||
NewSelect() SelectQuery
|
||||
NewInsert() InsertQuery
|
||||
NewUpdate() UpdateQuery
|
||||
NewDelete() DeleteQuery
|
||||
// ... transaction methods
|
||||
}
|
||||
```
|
||||
|
||||
### Available Adapters
|
||||
- `GormAdapter` - For GORM (ready to use)
|
||||
- `BunAdapter` - For Bun (add dependency: `go get github.com/uptrace/bun`)
|
||||
- Easy to create custom adapters for other ORMs
|
||||
|
||||
## Testing Benefits
|
||||
|
||||
### Before (Tightly Coupled)
|
||||
```go
|
||||
// Hard to test - requires real GORM setup
|
||||
func TestHandler(t *testing.T) {
|
||||
db := setupRealGormDB()
|
||||
handler := resolvespec.NewAPIHandler(db)
|
||||
// ... test logic
|
||||
}
|
||||
```
|
||||
|
||||
### After (Mockable)
|
||||
```go
|
||||
// Easy to test - mock the interfaces
|
||||
func TestHandler(t *testing.T) {
|
||||
mockDB := &MockDatabase{}
|
||||
mockRegistry := &MockModelRegistry{}
|
||||
handler := resolvespec.NewHandler(mockDB, mockRegistry)
|
||||
// ... test logic with mocks
|
||||
}
|
||||
```
|
||||
|
||||
## Breaking Changes
|
||||
- **None for existing code** - Full backward compatibility maintained
|
||||
- New interfaces are additive, not replacing existing APIs
|
||||
|
||||
## Recommended Migration Timeline
|
||||
1. **Phase 1**: Use existing code (no changes needed)
|
||||
2. **Phase 2**: Gradually adopt new constructors (`NewHandlerWithGORM`)
|
||||
3. **Phase 3**: Move to interface-based approach when needed
|
||||
4. **Phase 4**: Switch database backends if desired
|
||||
|
||||
## Getting Help
|
||||
- Check example functions in `resolvespec.go`
|
||||
- Review interface definitions in `database.go`
|
||||
- Examine adapter implementations for patterns
|
||||
66
Makefile
Normal file
66
Makefile
Normal file
@@ -0,0 +1,66 @@
|
||||
.PHONY: test test-unit test-integration docker-up docker-down clean
|
||||
|
||||
# Run all unit tests
|
||||
test-unit:
|
||||
@echo "Running unit tests..."
|
||||
@go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
|
||||
|
||||
# Run all integration tests (requires PostgreSQL)
|
||||
test-integration:
|
||||
@echo "Running integration tests..."
|
||||
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
|
||||
|
||||
# Run all tests (unit + integration)
|
||||
test: test-unit test-integration
|
||||
|
||||
# Start PostgreSQL for integration tests
|
||||
docker-up:
|
||||
@echo "Starting PostgreSQL container..."
|
||||
@docker-compose up -d postgres-test
|
||||
@echo "Waiting for PostgreSQL to be ready..."
|
||||
@sleep 5
|
||||
@echo "PostgreSQL is ready!"
|
||||
|
||||
# Stop PostgreSQL container
|
||||
docker-down:
|
||||
@echo "Stopping PostgreSQL container..."
|
||||
@docker-compose down
|
||||
|
||||
# Clean up Docker volumes and test data
|
||||
clean:
|
||||
@echo "Cleaning up..."
|
||||
@docker-compose down -v
|
||||
@echo "Cleanup complete!"
|
||||
|
||||
# Run integration tests with Docker (full workflow)
|
||||
test-integration-docker: docker-up
|
||||
@echo "Running integration tests with Docker..."
|
||||
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
|
||||
@$(MAKE) docker-down
|
||||
|
||||
# Check test coverage
|
||||
coverage:
|
||||
@echo "Generating coverage report..."
|
||||
@go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
|
||||
@go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "Coverage report generated: coverage.html"
|
||||
|
||||
# Run integration tests coverage
|
||||
coverage-integration:
|
||||
@echo "Generating integration test coverage report..."
|
||||
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage-integration.out
|
||||
@go tool cover -html=coverage-integration.out -o coverage-integration.html
|
||||
@echo "Integration coverage report generated: coverage-integration.html"
|
||||
|
||||
help:
|
||||
@echo "Available targets:"
|
||||
@echo " test-unit - Run unit tests"
|
||||
@echo " test-integration - Run integration tests (requires PostgreSQL)"
|
||||
@echo " test - Run all tests"
|
||||
@echo " docker-up - Start PostgreSQL container"
|
||||
@echo " docker-down - Stop PostgreSQL container"
|
||||
@echo " test-integration-docker - Run integration tests with Docker (automated)"
|
||||
@echo " clean - Clean up Docker volumes"
|
||||
@echo " coverage - Generate unit test coverage report"
|
||||
@echo " coverage-integration - Generate integration test coverage report"
|
||||
@echo " help - Show this help message"
|
||||
@@ -1,17 +1,18 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/logger"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/testmodels"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/config"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/server"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/testmodels"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
@@ -20,12 +21,27 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize logger
|
||||
fmt.Println("ResolveSpec test server starting")
|
||||
logger.Init(true)
|
||||
// Load configuration
|
||||
cfgMgr := config.NewManager()
|
||||
if err := cfgMgr.Load(); err != nil {
|
||||
log.Fatalf("Failed to load configuration: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := cfgMgr.GetConfig()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to get configuration: %v", err)
|
||||
}
|
||||
|
||||
// Initialize logger with configuration
|
||||
logger.Init(cfg.Logger.Dev)
|
||||
if cfg.Logger.Path != "" {
|
||||
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
|
||||
}
|
||||
logger.Info("ResolveSpec test server starting")
|
||||
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
|
||||
|
||||
// Initialize database
|
||||
db, err := initDB()
|
||||
db, err := initDB(cfg)
|
||||
if err != nil {
|
||||
logger.Error("Failed to initialize database: %+v", err)
|
||||
os.Exit(1)
|
||||
@@ -48,32 +64,54 @@ func main() {
|
||||
handler.RegisterModel("public", modelNames[i], model)
|
||||
}
|
||||
|
||||
// Setup routes using new SetupMuxRoutes function
|
||||
resolvespec.SetupMuxRoutes(r, handler)
|
||||
// Setup routes using new SetupMuxRoutes function (without authentication)
|
||||
resolvespec.SetupMuxRoutes(r, handler, nil)
|
||||
|
||||
// Start server
|
||||
logger.Info("Starting server on :8080")
|
||||
if err := http.ListenAndServe(":8080", r); err != nil {
|
||||
// Create graceful server with configuration
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: cfg.Server.Addr,
|
||||
Handler: r,
|
||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
||||
DrainTimeout: cfg.Server.DrainTimeout,
|
||||
ReadTimeout: cfg.Server.ReadTimeout,
|
||||
WriteTimeout: cfg.Server.WriteTimeout,
|
||||
IdleTimeout: cfg.Server.IdleTimeout,
|
||||
})
|
||||
|
||||
// Start server with graceful shutdown
|
||||
logger.Info("Starting server on %s", cfg.Server.Addr)
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
logger.Error("Server failed to start: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func initDB() (*gorm.DB, error) {
|
||||
func initDB(cfg *config.Config) (*gorm.DB, error) {
|
||||
// Configure GORM logger based on config
|
||||
logLevel := gormlog.Info
|
||||
if !cfg.Logger.Dev {
|
||||
logLevel = gormlog.Warn
|
||||
}
|
||||
|
||||
newLogger := gormlog.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
|
||||
gormlog.Config{
|
||||
SlowThreshold: time.Second, // Slow SQL threshold
|
||||
LogLevel: gormlog.Info, // Log level
|
||||
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
||||
ParameterizedQueries: true, // Don't include params in the SQL log
|
||||
Colorful: true, // Disable color
|
||||
SlowThreshold: time.Second, // Slow SQL threshold
|
||||
LogLevel: logLevel, // Log level
|
||||
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
||||
ParameterizedQueries: true, // Don't include params in the SQL log
|
||||
Colorful: cfg.Logger.Dev,
|
||||
},
|
||||
)
|
||||
|
||||
// Use database URL from config if available, otherwise use default SQLite
|
||||
dbURL := cfg.Database.URL
|
||||
if dbURL == "" {
|
||||
dbURL = "test.db"
|
||||
}
|
||||
|
||||
// Create SQLite database
|
||||
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
|
||||
db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
41
config.yaml
Normal file
41
config.yaml
Normal file
@@ -0,0 +1,41 @@
|
||||
# ResolveSpec Test Server Configuration
|
||||
# This is a minimal configuration for the test server
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
|
||||
logger:
|
||||
dev: true # Enable development mode for readable logs
|
||||
path: "" # Empty means log to stdout
|
||||
|
||||
cache:
|
||||
provider: "memory"
|
||||
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "*"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
- "OPTIONS"
|
||||
allowed_headers:
|
||||
- "*"
|
||||
max_age: 3600
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
|
||||
database:
|
||||
url: "" # Empty means use default SQLite (test.db)
|
||||
57
config.yaml.example
Normal file
57
config.yaml.example
Normal file
@@ -0,0 +1,57 @@
|
||||
# ResolveSpec Configuration Example
|
||||
# This file demonstrates all available configuration options
|
||||
# Copy this file to config.yaml and customize as needed
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
service_name: "resolvespec"
|
||||
service_version: "1.0.0"
|
||||
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
|
||||
|
||||
cache:
|
||||
provider: "memory" # Options: memory, redis, memcache
|
||||
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
memcache:
|
||||
servers:
|
||||
- "localhost:11211"
|
||||
max_idle_conns: 10
|
||||
timeout: 100ms
|
||||
|
||||
logger:
|
||||
dev: false
|
||||
path: "" # Empty for stdout, or specify file path
|
||||
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB in bytes
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "*"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
- "OPTIONS"
|
||||
allowed_headers:
|
||||
- "*"
|
||||
max_age: 3600
|
||||
|
||||
database:
|
||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
|
||||
27
docker-compose.yml
Normal file
27
docker-compose.yml
Normal file
@@ -0,0 +1,27 @@
|
||||
services:
|
||||
postgres-test:
|
||||
image: postgres:15-alpine
|
||||
container_name: resolvespec-postgres-test
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: postgres
|
||||
ports:
|
||||
- "5434:5432"
|
||||
volumes:
|
||||
- postgres-test-data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
networks:
|
||||
- resolvespec-test
|
||||
|
||||
volumes:
|
||||
postgres-test-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
resolvespec-test:
|
||||
driver: bridge
|
||||
78
go.mod
78
go.mod
@@ -1,39 +1,95 @@
|
||||
module github.com/Warky-Devs/ResolveSpec
|
||||
module github.com/bitechdev/ResolveSpec
|
||||
|
||||
go 1.23.0
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.6
|
||||
|
||||
require (
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/stretchr/testify v1.8.1
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/redis/go-redis/v9 v9.17.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/uptrace/bun v1.2.15
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15
|
||||
github.com/uptrace/bunrouter v1.0.23
|
||||
go.opentelemetry.io/otel v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
||||
go.opentelemetry.io/otel/sdk v1.38.0
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.25.12
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
github.com/jackc/pgx/v5 v5.6.0 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncruces/go-strftime v0.1.9 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.66.1 // indirect
|
||||
github.com/prometheus/procfs v0.16.1 // indirect
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/sagikazarmark/locafero v0.11.0 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/spf13/viper v1.21.0 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
|
||||
github.com/uptrace/bunrouter v1.0.23 // indirect
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
golang.org/x/crypto v0.41.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
||||
google.golang.org/grpc v1.75.0 // indirect
|
||||
google.golang.org/protobuf v1.36.8 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
modernc.org/libc v1.22.5 // indirect
|
||||
modernc.org/mathutil v1.5.0 // indirect
|
||||
modernc.org/memory v1.5.0 // indirect
|
||||
modernc.org/sqlite v1.23.1 // indirect
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.0 // indirect
|
||||
)
|
||||
|
||||
201
go.sum
201
go.sum
@@ -1,76 +1,227 @@
|
||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
|
||||
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
|
||||
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
|
||||
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
|
||||
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
|
||||
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
|
||||
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
|
||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
|
||||
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
|
||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
|
||||
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
|
||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
|
||||
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
|
||||
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
|
||||
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
|
||||
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
|
||||
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
|
||||
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
|
||||
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
|
||||
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
|
||||
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE=
|
||||
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY=
|
||||
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ=
|
||||
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E=
|
||||
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds=
|
||||
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU=
|
||||
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM=
|
||||
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||
modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
|
||||
modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
|
||||
modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
|
||||
modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
|
||||
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
|
||||
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
||||
|
||||
@@ -4,18 +4,63 @@
|
||||
read -p "Do you want to make a release version? (y/n): " make_release
|
||||
|
||||
if [[ $make_release =~ ^[Yy]$ ]]; then
|
||||
# Ask the user for the version number
|
||||
read -p "Enter the version number : " version
|
||||
# Get the latest tag from git
|
||||
latest_tag=$(git describe --tags --abbrev=0 2>/dev/null)
|
||||
|
||||
if [ -z "$latest_tag" ]; then
|
||||
# No tags exist yet, start with v1.0.0
|
||||
suggested_version="v1.0.0"
|
||||
echo "No existing tags found. Starting with $suggested_version"
|
||||
else
|
||||
echo "Latest tag: $latest_tag"
|
||||
|
||||
# Remove 'v' prefix if present
|
||||
version_number="${latest_tag#v}"
|
||||
|
||||
# Split version into major.minor.patch
|
||||
IFS='.' read -r major minor patch <<< "$version_number"
|
||||
|
||||
# Increment patch version
|
||||
patch=$((patch + 1))
|
||||
|
||||
# Construct new version
|
||||
suggested_version="v${major}.${minor}.${patch}"
|
||||
echo "Suggested next version: $suggested_version"
|
||||
fi
|
||||
|
||||
# Ask the user for the version number with the suggested version as default
|
||||
read -p "Enter the version number (press Enter for $suggested_version): " version
|
||||
|
||||
# Use suggested version if user pressed Enter without input
|
||||
if [ -z "$version" ]; then
|
||||
version="$suggested_version"
|
||||
fi
|
||||
|
||||
# Prepend 'v' to the version if it doesn't start with it
|
||||
if ! [[ $version =~ ^v ]]; then
|
||||
version="v$version"
|
||||
else
|
||||
echo "Version already starts with 'v'."
|
||||
fi
|
||||
|
||||
# Create an annotated tag
|
||||
git tag -a "$version" -m "Released $version"
|
||||
# Get commit logs since the last tag
|
||||
if [ -z "$latest_tag" ]; then
|
||||
# No previous tag, get all commits
|
||||
commit_logs=$(git log --pretty=format:"- %s" --no-merges)
|
||||
else
|
||||
# Get commits since the last tag
|
||||
commit_logs=$(git log "${latest_tag}..HEAD" --pretty=format:"- %s" --no-merges)
|
||||
fi
|
||||
|
||||
# Create the tag message
|
||||
if [ -z "$commit_logs" ]; then
|
||||
tag_message="Release $version"
|
||||
else
|
||||
tag_message="Release $version
|
||||
|
||||
${commit_logs}"
|
||||
fi
|
||||
|
||||
# Create an annotated tag with the commit logs
|
||||
git tag -a "$version" -m "$tag_message"
|
||||
|
||||
# Push the tag to the remote repository
|
||||
git push origin "$version"
|
||||
|
||||
340
pkg/cache/README.md
vendored
Normal file
340
pkg/cache/README.md
vendored
Normal file
@@ -0,0 +1,340 @@
|
||||
# Cache Package
|
||||
|
||||
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
|
||||
- **Pluggable Architecture**: Easy to add custom cache providers
|
||||
- **Type-Safe API**: Automatic JSON serialization/deserialization
|
||||
- **TTL Support**: Configurable time-to-live for cache entries
|
||||
- **Context-Aware**: All operations support Go contexts
|
||||
- **Statistics**: Built-in cache statistics and monitoring
|
||||
- **Pattern Deletion**: Delete keys by pattern (Redis)
|
||||
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
go get github.com/bitechdev/ResolveSpec/pkg/cache
|
||||
```
|
||||
|
||||
For Redis support:
|
||||
```bash
|
||||
go get github.com/redis/go-redis/v9
|
||||
```
|
||||
|
||||
For Memcache support:
|
||||
```bash
|
||||
go get github.com/bradfitz/gomemcache/memcache
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### In-Memory Cache
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize with in-memory provider
|
||||
cache.UseMemory(&cache.Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxSize: 10000,
|
||||
})
|
||||
defer cache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := cache.GetDefaultCache()
|
||||
|
||||
// Store a value
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
user := User{ID: 1, Name: "John"}
|
||||
c.Set(ctx, "user:1", user, 10*time.Minute)
|
||||
|
||||
// Retrieve a value
|
||||
var retrieved User
|
||||
c.Get(ctx, "user:1", &retrieved)
|
||||
}
|
||||
```
|
||||
|
||||
### Redis Cache
|
||||
|
||||
```go
|
||||
cache.UseRedis(&cache.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "",
|
||||
DB: 0,
|
||||
Options: &cache.Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
})
|
||||
defer cache.Close()
|
||||
```
|
||||
|
||||
### Memcache
|
||||
|
||||
```go
|
||||
cache.UseMemcache(&cache.MemcacheConfig{
|
||||
Servers: []string{"localhost:11211"},
|
||||
Options: &cache.Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
})
|
||||
defer cache.Close()
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Core Methods
|
||||
|
||||
#### Set
|
||||
```go
|
||||
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
|
||||
```
|
||||
Stores a value in the cache with automatic JSON serialization.
|
||||
|
||||
#### Get
|
||||
```go
|
||||
Get(ctx context.Context, key string, dest interface{}) error
|
||||
```
|
||||
Retrieves and deserializes a value from the cache.
|
||||
|
||||
#### SetBytes / GetBytes
|
||||
```go
|
||||
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
GetBytes(ctx context.Context, key string) ([]byte, error)
|
||||
```
|
||||
Store and retrieve raw bytes without serialization.
|
||||
|
||||
#### Delete
|
||||
```go
|
||||
Delete(ctx context.Context, key string) error
|
||||
```
|
||||
Removes a key from the cache.
|
||||
|
||||
#### DeleteByPattern
|
||||
```go
|
||||
DeleteByPattern(ctx context.Context, pattern string) error
|
||||
```
|
||||
Removes all keys matching a pattern (Redis only).
|
||||
|
||||
#### Clear
|
||||
```go
|
||||
Clear(ctx context.Context) error
|
||||
```
|
||||
Removes all items from the cache.
|
||||
|
||||
#### Exists
|
||||
```go
|
||||
Exists(ctx context.Context, key string) bool
|
||||
```
|
||||
Checks if a key exists in the cache.
|
||||
|
||||
#### GetOrSet
|
||||
```go
|
||||
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
|
||||
loader func() (interface{}, error)) error
|
||||
```
|
||||
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
|
||||
|
||||
#### Stats
|
||||
```go
|
||||
Stats(ctx context.Context) (*CacheStats, error)
|
||||
```
|
||||
Returns cache statistics including hits, misses, and key counts.
|
||||
|
||||
## Provider Configuration
|
||||
|
||||
### In-Memory Options
|
||||
|
||||
```go
|
||||
&cache.Options{
|
||||
DefaultTTL: 5 * time.Minute, // Default expiration time
|
||||
MaxSize: 10000, // Maximum number of items
|
||||
EvictionPolicy: "LRU", // Eviction strategy (future)
|
||||
}
|
||||
```
|
||||
|
||||
### Redis Configuration
|
||||
|
||||
```go
|
||||
&cache.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "", // Optional authentication
|
||||
DB: 0, // Database number
|
||||
PoolSize: 10, // Connection pool size
|
||||
Options: &cache.Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
### Memcache Configuration
|
||||
|
||||
```go
|
||||
&cache.MemcacheConfig{
|
||||
Servers: []string{"localhost:11211"},
|
||||
MaxIdleConns: 2,
|
||||
Timeout: 1 * time.Second,
|
||||
Options: &cache.Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Custom Provider
|
||||
|
||||
```go
|
||||
// Create a custom provider instance
|
||||
memProvider := cache.NewMemoryProvider(&cache.Options{
|
||||
DefaultTTL: 10 * time.Minute,
|
||||
MaxSize: 500,
|
||||
})
|
||||
|
||||
// Initialize with custom provider
|
||||
cache.Initialize(memProvider)
|
||||
```
|
||||
|
||||
### Lazy Loading Pattern
|
||||
|
||||
```go
|
||||
var data ExpensiveData
|
||||
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
|
||||
// This expensive operation only runs if key is not in cache
|
||||
return computeExpensiveData(), nil
|
||||
})
|
||||
```
|
||||
|
||||
### Query API Cache
|
||||
|
||||
The package includes specialized functions for caching query results:
|
||||
|
||||
```go
|
||||
// Cache a query result
|
||||
api := "GetUsers"
|
||||
query := "SELECT * FROM users WHERE active = true"
|
||||
tablenames := "users"
|
||||
total := int64(150)
|
||||
|
||||
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
|
||||
|
||||
// Retrieve cached query
|
||||
hash := cache.HashQueryAPICache(api, query)
|
||||
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
|
||||
```
|
||||
|
||||
## Provider Comparison
|
||||
|
||||
| Feature | In-Memory | Redis | Memcache |
|
||||
|---------|-----------|-------|----------|
|
||||
| Persistence | No | Yes | No |
|
||||
| Distributed | No | Yes | Yes |
|
||||
| Pattern Delete | No | Yes | No |
|
||||
| Statistics | Full | Full | Limited |
|
||||
| Atomic Operations | Yes | Yes | Yes |
|
||||
| Max Item Size | Memory | 512MB | 1MB |
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use contexts**: Always pass context for cancellation and timeout control
|
||||
2. **Set appropriate TTLs**: Balance between freshness and performance
|
||||
3. **Handle errors**: Cache misses and errors should be handled gracefully
|
||||
4. **Monitor statistics**: Use Stats() to monitor cache performance
|
||||
5. **Clean up**: Always call Close() when shutting down
|
||||
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
|
||||
|
||||
## Example: Complete Application
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
func NewUserService() *UserService {
|
||||
// Initialize with Redis in production, memory for testing
|
||||
cache.UseRedis(&cache.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Options: &cache.Options{
|
||||
DefaultTTL: 10 * time.Minute,
|
||||
},
|
||||
})
|
||||
|
||||
return &UserService{
|
||||
cache: cache.GetDefaultCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
|
||||
var user User
|
||||
cacheKey := fmt.Sprintf("user:%d", userID)
|
||||
|
||||
// Try to get from cache first
|
||||
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
|
||||
// Load from database if not in cache
|
||||
return s.loadUserFromDB(userID)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
|
||||
cacheKey := fmt.Sprintf("user:%d", userID)
|
||||
return s.cache.Delete(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func main() {
|
||||
service := NewUserService()
|
||||
defer cache.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
user, err := service.GetUser(ctx, 123)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
log.Printf("User: %+v", user)
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
- **In-Memory**: Fastest but limited by RAM and not distributed
|
||||
- **Redis**: Great for distributed systems, persistent, but network overhead
|
||||
- **Memcache**: Good for distributed caching, simpler than Redis but less features
|
||||
|
||||
Choose based on your needs:
|
||||
- Single instance? Use in-memory
|
||||
- Need persistence or advanced features? Use Redis
|
||||
- Simple distributed cache? Use Memcache
|
||||
|
||||
## License
|
||||
|
||||
See repository license.
|
||||
76
pkg/cache/cache.go
vendored
Normal file
76
pkg/cache/cache.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultCache *Cache
|
||||
)
|
||||
|
||||
// Initialize initializes the cache with a provider.
|
||||
// If not called, the package will use an in-memory provider by default.
|
||||
func Initialize(provider Provider) {
|
||||
defaultCache = NewCache(provider)
|
||||
}
|
||||
|
||||
// UseMemory configures the cache to use in-memory storage.
|
||||
func UseMemory(opts *Options) error {
|
||||
provider := NewMemoryProvider(opts)
|
||||
defaultCache = NewCache(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseRedis configures the cache to use Redis storage.
|
||||
func UseRedis(config *RedisConfig) error {
|
||||
provider, err := NewRedisProvider(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize Redis provider: %w", err)
|
||||
}
|
||||
defaultCache = NewCache(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UseMemcache configures the cache to use Memcache storage.
|
||||
func UseMemcache(config *MemcacheConfig) error {
|
||||
provider, err := NewMemcacheProvider(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
|
||||
}
|
||||
defaultCache = NewCache(provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultCache returns the default cache instance.
|
||||
// Initializes with in-memory provider if not already initialized.
|
||||
func GetDefaultCache() *Cache {
|
||||
if defaultCache == nil {
|
||||
_ = UseMemory(&Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxSize: 10000,
|
||||
})
|
||||
}
|
||||
return defaultCache
|
||||
}
|
||||
|
||||
// SetDefaultCache sets a custom cache instance as the default cache.
|
||||
// This is useful for testing or when you want to use a pre-configured cache instance.
|
||||
func SetDefaultCache(cache *Cache) {
|
||||
defaultCache = cache
|
||||
}
|
||||
|
||||
// GetStats returns cache statistics.
|
||||
func GetStats(ctx context.Context) (*CacheStats, error) {
|
||||
cache := GetDefaultCache()
|
||||
return cache.Stats(ctx)
|
||||
}
|
||||
|
||||
// Close closes the cache and releases resources.
|
||||
func Close() error {
|
||||
if defaultCache != nil {
|
||||
return defaultCache.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
147
pkg/cache/cache_manager.go
vendored
Normal file
147
pkg/cache/cache_manager.go
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Cache is the main cache manager that wraps a Provider.
|
||||
type Cache struct {
|
||||
provider Provider
|
||||
}
|
||||
|
||||
// NewCache creates a new cache manager with the specified provider.
|
||||
func NewCache(provider Provider) *Cache {
|
||||
return &Cache{
|
||||
provider: provider,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves and deserializes a value from the cache.
|
||||
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
data, exists := c.provider.Get(ctx, key)
|
||||
if !exists {
|
||||
return fmt.Errorf("key not found: %s", key)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, dest); err != nil {
|
||||
return fmt.Errorf("failed to deserialize: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBytes retrieves raw bytes from the cache.
|
||||
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
|
||||
data, exists := c.provider.Get(ctx, key)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("key not found: %s", key)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Set serializes and stores a value in the cache with the specified TTL.
|
||||
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize: %w", err)
|
||||
}
|
||||
|
||||
return c.provider.Set(ctx, key, data, ttl)
|
||||
}
|
||||
|
||||
// SetBytes stores raw bytes in the cache with the specified TTL.
|
||||
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
return c.provider.Set(ctx, key, value, ttl)
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (c *Cache) Delete(ctx context.Context, key string) error {
|
||||
return c.provider.Delete(ctx, key)
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
return c.provider.DeleteByPattern(ctx, pattern)
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache.
|
||||
func (c *Cache) Clear(ctx context.Context) error {
|
||||
return c.provider.Clear(ctx)
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache.
|
||||
func (c *Cache) Exists(ctx context.Context, key string) bool {
|
||||
return c.provider.Exists(ctx, key)
|
||||
}
|
||||
|
||||
// Stats returns statistics about the cache.
|
||||
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
|
||||
return c.provider.Stats(ctx)
|
||||
}
|
||||
|
||||
// Close closes the cache and releases any resources.
|
||||
func (c *Cache) Close() error {
|
||||
return c.provider.Close()
|
||||
}
|
||||
|
||||
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
|
||||
// The loader function is called only if the key is not found in cache.
|
||||
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
|
||||
// Try to get from cache first
|
||||
err := c.Get(ctx, key, dest)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load the value
|
||||
value, err := loader()
|
||||
if err != nil {
|
||||
return fmt.Errorf("loader failed: %w", err)
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
if err := c.Set(ctx, key, value, ttl); err != nil {
|
||||
return fmt.Errorf("failed to cache value: %w", err)
|
||||
}
|
||||
|
||||
// Populate dest with the loaded value
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize loaded value: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, dest); err != nil {
|
||||
return fmt.Errorf("failed to deserialize loaded value: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Remember is a convenience function that caches the result of a function call.
|
||||
// It's similar to GetOrSet but returns the value directly.
|
||||
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
|
||||
// Try to get from cache first as bytes
|
||||
data, err := c.GetBytes(ctx, key)
|
||||
if err == nil {
|
||||
var result interface{}
|
||||
if err := json.Unmarshal(data, &result); err == nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Load the value
|
||||
value, err := loader()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loader failed: %w", err)
|
||||
}
|
||||
|
||||
// Store in cache
|
||||
if err := c.Set(ctx, key, value, ttl); err != nil {
|
||||
return nil, fmt.Errorf("failed to cache value: %w", err)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
69
pkg/cache/cache_test.go
vendored
Normal file
69
pkg/cache/cache_test.go
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSetDefaultCache(t *testing.T) {
|
||||
// Create a custom cache instance
|
||||
provider := NewMemoryProvider(&Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 50,
|
||||
})
|
||||
customCache := NewCache(provider)
|
||||
|
||||
// Set it as the default
|
||||
SetDefaultCache(customCache)
|
||||
|
||||
// Verify it's now the default
|
||||
retrievedCache := GetDefaultCache()
|
||||
if retrievedCache != customCache {
|
||||
t.Error("SetDefaultCache did not set the cache correctly")
|
||||
}
|
||||
|
||||
// Test that we can use it
|
||||
ctx := context.Background()
|
||||
testKey := "test_key"
|
||||
testValue := "test_value"
|
||||
|
||||
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set value: %v", err)
|
||||
}
|
||||
|
||||
var result string
|
||||
err = retrievedCache.Get(ctx, testKey, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get value: %v", err)
|
||||
}
|
||||
|
||||
if result != testValue {
|
||||
t.Errorf("Expected %s, got %s", testValue, result)
|
||||
}
|
||||
|
||||
// Clean up - reset to default
|
||||
SetDefaultCache(nil)
|
||||
}
|
||||
|
||||
func TestGetDefaultCacheInitialization(t *testing.T) {
|
||||
// Reset to nil first
|
||||
SetDefaultCache(nil)
|
||||
|
||||
// GetDefaultCache should auto-initialize
|
||||
cache := GetDefaultCache()
|
||||
if cache == nil {
|
||||
t.Error("GetDefaultCache should auto-initialize, got nil")
|
||||
}
|
||||
|
||||
// Should be usable
|
||||
ctx := context.Background()
|
||||
err := cache.Set(ctx, "test", "value", time.Minute)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to use auto-initialized cache: %v", err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
SetDefaultCache(nil)
|
||||
}
|
||||
266
pkg/cache/example_usage.go
vendored
Normal file
266
pkg/cache/example_usage.go
vendored
Normal file
@@ -0,0 +1,266 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
|
||||
func ExampleInMemoryCache() {
|
||||
// Initialize with in-memory provider
|
||||
err := UseMemory(&Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get the cache instance
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Store a value
|
||||
type User struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
|
||||
user := User{ID: 1, Name: "John Doe"}
|
||||
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Retrieve a value
|
||||
var retrieved User
|
||||
err = cache.Get(ctx, "user:1", &retrieved)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Retrieved user: %+v\n", retrieved)
|
||||
|
||||
// Check if key exists
|
||||
exists := cache.Exists(ctx, "user:1")
|
||||
fmt.Printf("Key exists: %v\n", exists)
|
||||
|
||||
// Delete a key
|
||||
err = cache.Delete(ctx, "user:1")
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Get statistics
|
||||
stats, err := cache.Stats(ctx)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Cache stats: %+v\n", stats)
|
||||
_ = Close()
|
||||
}
|
||||
|
||||
// ExampleRedisCache demonstrates using the Redis cache provider.
|
||||
func ExampleRedisCache() {
|
||||
// Initialize with Redis provider
|
||||
err := UseRedis(&RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "", // Set if Redis requires authentication
|
||||
DB: 0,
|
||||
Options: &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get the cache instance
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Store raw bytes
|
||||
data := []byte("Hello, Redis!")
|
||||
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Retrieve raw bytes
|
||||
retrieved, err := cache.GetBytes(ctx, "greeting")
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Retrieved data: %s\n", string(retrieved))
|
||||
|
||||
// Clear all cache
|
||||
err = cache.Clear(ctx)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
_ = Close()
|
||||
}
|
||||
|
||||
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
|
||||
func ExampleMemcacheCache() {
|
||||
// Initialize with Memcache provider
|
||||
err := UseMemcache(&MemcacheConfig{
|
||||
Servers: []string{"localhost:11211"},
|
||||
Options: &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Get the cache instance
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Store a value
|
||||
type Product struct {
|
||||
ID int
|
||||
Name string
|
||||
Price float64
|
||||
}
|
||||
|
||||
product := Product{ID: 100, Name: "Widget", Price: 29.99}
|
||||
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Retrieve a value
|
||||
var retrieved Product
|
||||
err = cache.Get(ctx, "product:100", &retrieved)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Retrieved product: %+v\n", retrieved)
|
||||
_ = Close()
|
||||
}
|
||||
|
||||
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
|
||||
func ExampleGetOrSet() {
|
||||
err := UseMemory(&Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxSize: 1000,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cache := GetDefaultCache()
|
||||
|
||||
type ExpensiveData struct {
|
||||
Result string
|
||||
}
|
||||
|
||||
var data ExpensiveData
|
||||
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
|
||||
// This expensive operation only runs if the key is not in cache
|
||||
fmt.Println("Computing expensive result...")
|
||||
time.Sleep(1 * time.Second)
|
||||
return ExpensiveData{Result: "computed value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Data: %+v\n", data)
|
||||
|
||||
// Second call will use cached value
|
||||
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
|
||||
fmt.Println("This won't be called!")
|
||||
return ExpensiveData{Result: "new value"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Cached data: %+v\n", data)
|
||||
_ = Close()
|
||||
}
|
||||
|
||||
// ExampleCustomProvider demonstrates using a custom provider.
|
||||
func ExampleCustomProvider() {
|
||||
// Create a custom provider
|
||||
memProvider := NewMemoryProvider(&Options{
|
||||
DefaultTTL: 10 * time.Minute,
|
||||
MaxSize: 500,
|
||||
})
|
||||
|
||||
// Initialize with custom provider
|
||||
Initialize(memProvider)
|
||||
|
||||
ctx := context.Background()
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Use the cache
|
||||
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Clean expired items (memory provider specific)
|
||||
if mp, ok := cache.provider.(*MemoryProvider); ok {
|
||||
count := mp.CleanExpired(ctx)
|
||||
fmt.Printf("Cleaned %d expired items\n", count)
|
||||
}
|
||||
_ = Close()
|
||||
}
|
||||
|
||||
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
|
||||
func ExampleDeleteByPattern() {
|
||||
err := UseRedis(&RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Options: &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Store multiple keys with a pattern
|
||||
_ = cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
|
||||
_ = cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
|
||||
_ = cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
|
||||
|
||||
// Delete all keys matching pattern (Redis glob pattern)
|
||||
err = cache.DeleteByPattern(ctx, "user:*:profile")
|
||||
if err != nil {
|
||||
_ = Close()
|
||||
log.Print(err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Deleted all user profile keys")
|
||||
_ = Close()
|
||||
}
|
||||
57
pkg/cache/provider.go
vendored
Normal file
57
pkg/cache/provider.go
vendored
Normal file
@@ -0,0 +1,57 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Provider defines the interface that all cache providers must implement.
|
||||
type Provider interface {
|
||||
// Get retrieves a value from the cache by key.
|
||||
// Returns nil, false if key doesn't exist or is expired.
|
||||
Get(ctx context.Context, key string) ([]byte, bool)
|
||||
|
||||
// Set stores a value in the cache with the specified TTL.
|
||||
// If ttl is 0, the item never expires.
|
||||
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
Delete(ctx context.Context, key string) error
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Pattern syntax depends on the provider implementation.
|
||||
DeleteByPattern(ctx context.Context, pattern string) error
|
||||
|
||||
// Clear removes all items from the cache.
|
||||
Clear(ctx context.Context) error
|
||||
|
||||
// Exists checks if a key exists in the cache.
|
||||
Exists(ctx context.Context, key string) bool
|
||||
|
||||
// Close closes the provider and releases any resources.
|
||||
Close() error
|
||||
|
||||
// Stats returns statistics about the cache provider.
|
||||
Stats(ctx context.Context) (*CacheStats, error)
|
||||
}
|
||||
|
||||
// CacheStats contains cache statistics.
|
||||
type CacheStats struct {
|
||||
Hits int64 `json:"hits"`
|
||||
Misses int64 `json:"misses"`
|
||||
Keys int64 `json:"keys"`
|
||||
ProviderType string `json:"provider_type"`
|
||||
ProviderStats map[string]any `json:"provider_stats,omitempty"`
|
||||
}
|
||||
|
||||
// Options contains configuration options for cache providers.
|
||||
type Options struct {
|
||||
// DefaultTTL is the default time-to-live for cache items.
|
||||
DefaultTTL time.Duration
|
||||
|
||||
// MaxSize is the maximum number of items (for in-memory provider).
|
||||
MaxSize int
|
||||
|
||||
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
|
||||
EvictionPolicy string
|
||||
}
|
||||
144
pkg/cache/provider_memcache.go
vendored
Normal file
144
pkg/cache/provider_memcache.go
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/bradfitz/gomemcache/memcache"
|
||||
)
|
||||
|
||||
// MemcacheProvider is a Memcache implementation of the Provider interface.
|
||||
type MemcacheProvider struct {
|
||||
client *memcache.Client
|
||||
options *Options
|
||||
}
|
||||
|
||||
// MemcacheConfig contains Memcache-specific configuration.
|
||||
type MemcacheConfig struct {
|
||||
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
|
||||
Servers []string
|
||||
|
||||
// MaxIdleConns is the maximum number of idle connections (default: 2)
|
||||
MaxIdleConns int
|
||||
|
||||
// Timeout for connection operations (default: 1 second)
|
||||
Timeout time.Duration
|
||||
|
||||
// Options contains general cache options
|
||||
Options *Options
|
||||
}
|
||||
|
||||
// NewMemcacheProvider creates a new Memcache cache provider.
|
||||
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
|
||||
if config == nil {
|
||||
config = &MemcacheConfig{
|
||||
Servers: []string{"localhost:11211"},
|
||||
}
|
||||
}
|
||||
|
||||
if len(config.Servers) == 0 {
|
||||
config.Servers = []string{"localhost:11211"}
|
||||
}
|
||||
|
||||
if config.MaxIdleConns == 0 {
|
||||
config.MaxIdleConns = 2
|
||||
}
|
||||
|
||||
if config.Timeout == 0 {
|
||||
config.Timeout = 1 * time.Second
|
||||
}
|
||||
|
||||
if config.Options == nil {
|
||||
config.Options = &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
client := memcache.New(config.Servers...)
|
||||
client.MaxIdleConns = config.MaxIdleConns
|
||||
client.Timeout = config.Timeout
|
||||
|
||||
// Test connection
|
||||
if err := client.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
|
||||
}
|
||||
|
||||
return &MemcacheProvider{
|
||||
client: client,
|
||||
options: config.Options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache by key.
|
||||
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||
item, err := m.client.Get(key)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil, false
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return item.Value, true
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL.
|
||||
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
item := &memcache.Item{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Expiration: int32(ttl.Seconds()),
|
||||
}
|
||||
|
||||
return m.client.Set(item)
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
|
||||
err := m.client.Delete(key)
|
||||
if err == memcache.ErrCacheMiss {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
// Note: Memcache does not support pattern-based deletion natively.
|
||||
// This is a no-op for memcache and returns an error.
|
||||
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache.
|
||||
func (m *MemcacheProvider) Clear(ctx context.Context) error {
|
||||
return m.client.FlushAll()
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache.
|
||||
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
|
||||
_, err := m.client.Get(key)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Close closes the provider and releases any resources.
|
||||
func (m *MemcacheProvider) Close() error {
|
||||
// Memcache client doesn't have a close method
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns statistics about the cache provider.
|
||||
// Note: Memcache provider returns limited statistics.
|
||||
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||
stats := &CacheStats{
|
||||
ProviderType: "memcache",
|
||||
ProviderStats: map[string]any{
|
||||
"note": "Memcache does not provide detailed statistics through the standard client",
|
||||
},
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
238
pkg/cache/provider_memory.go
vendored
Normal file
238
pkg/cache/provider_memory.go
vendored
Normal file
@@ -0,0 +1,238 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// memoryItem represents a cached item in memory.
|
||||
type memoryItem struct {
|
||||
Value []byte
|
||||
Expiration time.Time
|
||||
LastAccess time.Time
|
||||
HitCount int64
|
||||
}
|
||||
|
||||
// isExpired checks if the item has expired.
|
||||
func (m *memoryItem) isExpired() bool {
|
||||
if m.Expiration.IsZero() {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(m.Expiration)
|
||||
}
|
||||
|
||||
// MemoryProvider is an in-memory implementation of the Provider interface.
|
||||
type MemoryProvider struct {
|
||||
mu sync.RWMutex
|
||||
items map[string]*memoryItem
|
||||
options *Options
|
||||
hits atomic.Int64
|
||||
misses atomic.Int64
|
||||
}
|
||||
|
||||
// NewMemoryProvider creates a new in-memory cache provider.
|
||||
func NewMemoryProvider(opts *Options) *MemoryProvider {
|
||||
if opts == nil {
|
||||
opts = &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
MaxSize: 10000,
|
||||
}
|
||||
}
|
||||
|
||||
return &MemoryProvider{
|
||||
items: make(map[string]*memoryItem),
|
||||
options: opts,
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache by key.
|
||||
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||
// First try with read lock for fast path
|
||||
m.mu.RLock()
|
||||
item, exists := m.items[key]
|
||||
if !exists {
|
||||
m.mu.RUnlock()
|
||||
m.misses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if item.isExpired() {
|
||||
m.mu.RUnlock()
|
||||
// Upgrade to write lock to delete expired item
|
||||
m.mu.Lock()
|
||||
delete(m.items, key)
|
||||
m.mu.Unlock()
|
||||
m.misses.Add(1)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Update stats and access time with write lock
|
||||
value := item.Value
|
||||
m.mu.RUnlock()
|
||||
|
||||
// Update access tracking with write lock
|
||||
m.mu.Lock()
|
||||
item.LastAccess = time.Now()
|
||||
item.HitCount++
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hits.Add(1)
|
||||
return value, true
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL.
|
||||
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if ttl == 0 {
|
||||
ttl = m.options.DefaultTTL
|
||||
}
|
||||
|
||||
var expiration time.Time
|
||||
if ttl > 0 {
|
||||
expiration = time.Now().Add(ttl)
|
||||
}
|
||||
|
||||
// Check max size and evict if necessary
|
||||
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
|
||||
if _, exists := m.items[key]; !exists {
|
||||
m.evictOne()
|
||||
}
|
||||
}
|
||||
|
||||
m.items[key] = &memoryItem{
|
||||
Value: value,
|
||||
Expiration: expiration,
|
||||
LastAccess: time.Now(),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.items, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid pattern: %w", err)
|
||||
}
|
||||
|
||||
for key := range m.items {
|
||||
if re.MatchString(key) {
|
||||
delete(m.items, key)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache.
|
||||
func (m *MemoryProvider) Clear(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = make(map[string]*memoryItem)
|
||||
m.hits.Store(0)
|
||||
m.misses.Store(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache.
|
||||
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
item, exists := m.items[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
return !item.isExpired()
|
||||
}
|
||||
|
||||
// Close closes the provider and releases any resources.
|
||||
func (m *MemoryProvider) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.items = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats returns statistics about the cache provider.
|
||||
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Clean expired items first
|
||||
validKeys := 0
|
||||
for _, item := range m.items {
|
||||
if !item.isExpired() {
|
||||
validKeys++
|
||||
}
|
||||
}
|
||||
|
||||
return &CacheStats{
|
||||
Hits: m.hits.Load(),
|
||||
Misses: m.misses.Load(),
|
||||
Keys: int64(validKeys),
|
||||
ProviderType: "memory",
|
||||
ProviderStats: map[string]any{
|
||||
"capacity": m.options.MaxSize,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// evictOne removes one item from the cache using LRU strategy.
|
||||
func (m *MemoryProvider) evictOne() {
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
delete(m.items, key)
|
||||
return
|
||||
}
|
||||
|
||||
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
|
||||
oldestKey = key
|
||||
oldestTime = item.LastAccess
|
||||
}
|
||||
}
|
||||
|
||||
if oldestKey != "" {
|
||||
delete(m.items, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanExpired removes all expired items from the cache.
|
||||
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
count := 0
|
||||
for key, item := range m.items {
|
||||
if item.isExpired() {
|
||||
delete(m.items, key)
|
||||
count++
|
||||
}
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
185
pkg/cache/provider_redis.go
vendored
Normal file
185
pkg/cache/provider_redis.go
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RedisProvider is a Redis implementation of the Provider interface.
|
||||
type RedisProvider struct {
|
||||
client *redis.Client
|
||||
options *Options
|
||||
}
|
||||
|
||||
// RedisConfig contains Redis-specific configuration.
|
||||
type RedisConfig struct {
|
||||
// Host is the Redis server host (default: localhost)
|
||||
Host string
|
||||
|
||||
// Port is the Redis server port (default: 6379)
|
||||
Port int
|
||||
|
||||
// Password for Redis authentication (optional)
|
||||
Password string
|
||||
|
||||
// DB is the Redis database number (default: 0)
|
||||
DB int
|
||||
|
||||
// PoolSize is the maximum number of connections (default: 10)
|
||||
PoolSize int
|
||||
|
||||
// Options contains general cache options
|
||||
Options *Options
|
||||
}
|
||||
|
||||
// NewRedisProvider creates a new Redis cache provider.
|
||||
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
|
||||
if config == nil {
|
||||
config = &RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
DB: 0,
|
||||
}
|
||||
}
|
||||
|
||||
if config.Host == "" {
|
||||
config.Host = "localhost"
|
||||
}
|
||||
if config.Port == 0 {
|
||||
config.Port = 6379
|
||||
}
|
||||
if config.PoolSize == 0 {
|
||||
config.PoolSize = 10
|
||||
}
|
||||
|
||||
if config.Options == nil {
|
||||
config.Options = &Options{
|
||||
DefaultTTL: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
|
||||
Password: config.Password,
|
||||
DB: config.DB,
|
||||
PoolSize: config.PoolSize,
|
||||
})
|
||||
|
||||
// Test connection
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
|
||||
}
|
||||
|
||||
return &RedisProvider{
|
||||
client: client,
|
||||
options: config.Options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Get retrieves a value from the cache by key.
|
||||
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
|
||||
val, err := r.client.Get(ctx, key).Bytes()
|
||||
if err == redis.Nil {
|
||||
return nil, false
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return val, true
|
||||
}
|
||||
|
||||
// Set stores a value in the cache with the specified TTL.
|
||||
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
|
||||
if ttl == 0 {
|
||||
ttl = r.options.DefaultTTL
|
||||
}
|
||||
|
||||
return r.client.Set(ctx, key, value, ttl).Err()
|
||||
}
|
||||
|
||||
// Delete removes a key from the cache.
|
||||
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
|
||||
return r.client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// DeleteByPattern removes all keys matching the pattern.
|
||||
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
|
||||
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
|
||||
pipe := r.client.Pipeline()
|
||||
|
||||
count := 0
|
||||
for iter.Next(ctx) {
|
||||
pipe.Del(ctx, iter.Val())
|
||||
count++
|
||||
|
||||
// Execute pipeline in batches of 100
|
||||
if count%100 == 0 {
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
pipe = r.client.Pipeline()
|
||||
}
|
||||
}
|
||||
|
||||
if err := iter.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute remaining commands
|
||||
if count%100 != 0 {
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all items from the cache.
|
||||
func (r *RedisProvider) Clear(ctx context.Context) error {
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
// Exists checks if a key exists in the cache.
|
||||
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
|
||||
result, err := r.client.Exists(ctx, key).Result()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return result > 0
|
||||
}
|
||||
|
||||
// Close closes the provider and releases any resources.
|
||||
func (r *RedisProvider) Close() error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Stats returns statistics about the cache provider.
|
||||
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
|
||||
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
|
||||
}
|
||||
|
||||
dbSize, err := r.client.DBSize(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get DB size: %w", err)
|
||||
}
|
||||
|
||||
// Parse stats from INFO command
|
||||
// This is a simplified version - you may want to parse more detailed stats
|
||||
stats := &CacheStats{
|
||||
Keys: dbSize,
|
||||
ProviderType: "redis",
|
||||
ProviderStats: map[string]any{
|
||||
"info": info,
|
||||
},
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
127
pkg/cache/query_cache.go
vendored
Normal file
127
pkg/cache/query_cache.go
vendored
Normal file
@@ -0,0 +1,127 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// QueryCacheKey represents the components used to build a cache key for query total count
|
||||
type QueryCacheKey struct {
|
||||
TableName string `json:"table_name"`
|
||||
Filters []common.FilterOption `json:"filters"`
|
||||
Sort []common.SortOption `json:"sort"`
|
||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||
Expand []ExpandOptionKey `json:"expand,omitempty"`
|
||||
Distinct bool `json:"distinct,omitempty"`
|
||||
CursorForward string `json:"cursor_forward,omitempty"`
|
||||
CursorBackward string `json:"cursor_backward,omitempty"`
|
||||
}
|
||||
|
||||
// ExpandOptionKey represents expand options for cache key
|
||||
type ExpandOptionKey struct {
|
||||
Relation string `json:"relation"`
|
||||
Where string `json:"where,omitempty"`
|
||||
}
|
||||
|
||||
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
|
||||
// This is used to cache the total count of records matching a query
|
||||
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
|
||||
key := QueryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
}
|
||||
|
||||
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||
// Includes expand, distinct, and cursor pagination options
|
||||
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
|
||||
key := QueryCacheKey{
|
||||
TableName: tableName,
|
||||
Filters: filters,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
Distinct: distinct,
|
||||
CursorForward: cursorFwd,
|
||||
CursorBackward: cursorBwd,
|
||||
}
|
||||
|
||||
// Convert expand options to cache key format
|
||||
if len(expandOpts) > 0 {
|
||||
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
|
||||
for _, exp := range expandOpts {
|
||||
// Type assert to get the expand option fields we care about for caching
|
||||
if expMap, ok := exp.(map[string]interface{}); ok {
|
||||
expKey := ExpandOptionKey{}
|
||||
if rel, ok := expMap["relation"].(string); ok {
|
||||
expKey.Relation = rel
|
||||
}
|
||||
if where, ok := expMap["where"].(string); ok {
|
||||
expKey.Where = where
|
||||
}
|
||||
key.Expand = append(key.Expand, expKey)
|
||||
}
|
||||
}
|
||||
// Sort expand options for consistent hashing (already sorted by relation name above)
|
||||
}
|
||||
|
||||
// Serialize to JSON for consistent hashing
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
}
|
||||
|
||||
// hashString computes SHA256 hash of a string
|
||||
func hashString(s string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(s))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
|
||||
func GetQueryTotalCacheKey(hash string) string {
|
||||
return fmt.Sprintf("query_total:%s", hash)
|
||||
}
|
||||
|
||||
// CachedTotal represents a cached total count
|
||||
type CachedTotal struct {
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
// InvalidateCacheForTable removes all cached totals for a specific table
|
||||
// This should be called when data in the table changes (insert/update/delete)
|
||||
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
|
||||
cache := GetDefaultCache()
|
||||
|
||||
// Build a pattern to match all query totals for this table
|
||||
// Note: This requires pattern matching support in the provider
|
||||
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
|
||||
|
||||
return cache.DeleteByPattern(ctx, pattern)
|
||||
}
|
||||
151
pkg/cache/query_cache_test.go
vendored
Normal file
151
pkg/cache/query_cache_test.go
vendored
Normal file
@@ -0,0 +1,151 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestBuildQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
{Column: "age", Operator: "gt", Value: 25},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
|
||||
}
|
||||
|
||||
// Different parameters should generate different key
|
||||
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildExtendedQueryCacheKey(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "name", Direction: "asc"},
|
||||
}
|
||||
expandOpts := []interface{}{
|
||||
map[string]interface{}{
|
||||
"relation": "posts",
|
||||
"where": "status = 'published'",
|
||||
},
|
||||
}
|
||||
|
||||
// Generate cache key
|
||||
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
// Same parameters should generate same key
|
||||
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
|
||||
|
||||
if key1 != key2 {
|
||||
t.Errorf("Expected same cache keys for identical parameters")
|
||||
}
|
||||
|
||||
// Different distinct value should generate different key
|
||||
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
|
||||
|
||||
if key1 == key3 {
|
||||
t.Errorf("Expected different cache keys for different distinct values")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQueryTotalCacheKey(t *testing.T) {
|
||||
hash := "abc123"
|
||||
key := GetQueryTotalCacheKey(hash)
|
||||
|
||||
expected := "query_total:abc123"
|
||||
if key != expected {
|
||||
t.Errorf("Expected %s, got %s", expected, key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachedTotalIntegration(t *testing.T) {
|
||||
// Initialize cache with memory provider for testing
|
||||
UseMemory(&Options{
|
||||
DefaultTTL: 1 * time.Minute,
|
||||
MaxSize: 100,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create test data
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
}
|
||||
sorts := []common.SortOption{
|
||||
{Column: "created_at", Direction: "desc"},
|
||||
}
|
||||
|
||||
// Build cache key
|
||||
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
|
||||
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
|
||||
|
||||
// Store a total count in cache
|
||||
totalToCache := CachedTotal{Total: 42}
|
||||
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set cache: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve from cache
|
||||
var cachedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get from cache: %v", err)
|
||||
}
|
||||
|
||||
if cachedTotal.Total != 42 {
|
||||
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
|
||||
}
|
||||
|
||||
// Test cache miss
|
||||
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
|
||||
var missedTotal CachedTotal
|
||||
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for cache miss, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashString(t *testing.T) {
|
||||
input1 := "test string"
|
||||
input2 := "test string"
|
||||
input3 := "different string"
|
||||
|
||||
hash1 := hashString(input1)
|
||||
hash2 := hashString(input2)
|
||||
hash3 := hashString(input3)
|
||||
|
||||
// Same input should produce same hash
|
||||
if hash1 != hash2 {
|
||||
t.Errorf("Expected same hash for identical inputs")
|
||||
}
|
||||
|
||||
// Different input should produce different hash
|
||||
if hash1 == hash3 {
|
||||
t.Errorf("Expected different hash for different inputs")
|
||||
}
|
||||
|
||||
// Hash should be hex encoded SHA256 (64 characters)
|
||||
if len(hash1) != 64 {
|
||||
t.Errorf("Expected hash length of 64, got %d", len(hash1))
|
||||
}
|
||||
}
|
||||
@@ -4,10 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"github.com/uptrace/bun"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// BunAdapter adapts Bun to work with our Database interface
|
||||
@@ -22,7 +27,10 @@ func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{query: b.db.NewSelect()}
|
||||
return &BunSelectQuery{
|
||||
query: b.db.NewSelect(),
|
||||
db: b.db,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||
@@ -37,12 +45,22 @@ func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.db.ExecContext(ctx, query, args...)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
|
||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
||||
}
|
||||
|
||||
@@ -67,7 +85,12 @@ func (b *BunAdapter) RollbackTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||
// Create adapter with transaction
|
||||
adapter := &BunTxAdapter{tx: tx}
|
||||
@@ -77,17 +100,35 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
||||
|
||||
// BunSelectQuery implements SelectQuery for Bun
|
||||
type BunSelectQuery struct {
|
||||
query *bun.SelectQuery
|
||||
tableName string
|
||||
tableAlias string
|
||||
query *bun.SelectQuery
|
||||
db bun.IDB // Store DB connection for count queries
|
||||
hasModel bool // Track if Model() was called
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
deferredPreloads []deferredPreload // Preloads to execute as separate queries
|
||||
}
|
||||
|
||||
// deferredPreload represents a preload that will be executed as a separate query
|
||||
// to avoid PostgreSQL identifier length limits
|
||||
type deferredPreload struct {
|
||||
relation string
|
||||
apply []func(common.SelectQuery) common.SelectQuery
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true // Mark that we have a model
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
b.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
b.schema, b.tableName = parseTableName(fullTableName)
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
b.tableAlias = provider.TableAlias()
|
||||
}
|
||||
|
||||
return b
|
||||
@@ -95,7 +136,8 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
|
||||
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||
b.query = b.query.Table(table)
|
||||
b.tableName = table
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
b.schema, b.tableName = parseTableName(table)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -104,6 +146,15 @@ func (b *BunSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
if len(args) > 0 {
|
||||
b.query = b.query.ColumnExpr(query, args)
|
||||
} else {
|
||||
b.query = b.query.ColumnExpr(query)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
b.query = b.query.Where(query, args...)
|
||||
return b
|
||||
@@ -128,13 +179,9 @@ func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQu
|
||||
}
|
||||
}
|
||||
|
||||
// If no prefix provided, use the table name as prefix
|
||||
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||
if prefix == "" && b.tableName != "" {
|
||||
prefix = b.tableName
|
||||
// Extract just the table name if it has schema
|
||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||
prefix = prefix[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// If prefix is provided, add it as an alias in the join
|
||||
@@ -169,12 +216,9 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
|
||||
}
|
||||
}
|
||||
|
||||
// If no prefix provided, use the table name as prefix
|
||||
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||
if prefix == "" && b.tableName != "" {
|
||||
prefix = b.tableName
|
||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||
prefix = prefix[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Construct LEFT JOIN with prefix
|
||||
@@ -189,7 +233,7 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
|
||||
}
|
||||
}
|
||||
|
||||
b.query = b.query.Join("LEFT JOIN " + joinClause, sqlArgs...)
|
||||
b.query = b.query.Join("LEFT JOIN "+joinClause, sqlArgs...)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -201,6 +245,133 @@ func (b *BunSelectQuery) Preload(relation string, conditions ...interface{}) com
|
||||
return b
|
||||
}
|
||||
|
||||
// // shortenAliasForPostgres shortens a table/relation alias if it would exceed PostgreSQL's 63-char limit
|
||||
// // when combined with typical column names
|
||||
// func shortenAliasForPostgres(relationPath string) (string, bool) {
|
||||
// // Convert relation path to the alias format Bun uses: dots become double underscores
|
||||
// // Also convert to lowercase and use snake_case as Bun does
|
||||
// parts := strings.Split(relationPath, ".")
|
||||
// alias := strings.ToLower(strings.Join(parts, "__"))
|
||||
|
||||
// // PostgreSQL truncates identifiers to 63 chars
|
||||
// // If the alias + typical column name would exceed this, we need to shorten
|
||||
// // Reserve at least 30 chars for column names (e.g., "__rid_mastertype_hubtype")
|
||||
// const maxAliasLength = 30
|
||||
|
||||
// if len(alias) > maxAliasLength {
|
||||
// // Create a shortened alias using a hash of the original
|
||||
// hash := md5.Sum([]byte(alias))
|
||||
// hashStr := hex.EncodeToString(hash[:])[:8]
|
||||
|
||||
// // Keep first few chars of original for readability + hash
|
||||
// prefixLen := maxAliasLength - 9 // 9 = 1 underscore + 8 hash chars
|
||||
// if prefixLen > len(alias) {
|
||||
// prefixLen = len(alias)
|
||||
// }
|
||||
|
||||
// shortened := alias[:prefixLen] + "_" + hashStr
|
||||
// logger.Debug("Shortened alias '%s' (%d chars) to '%s' (%d chars) to avoid PostgreSQL 63-char limit",
|
||||
// alias, len(alias), shortened, len(shortened))
|
||||
// return shortened, true
|
||||
// }
|
||||
|
||||
// return alias, false
|
||||
// }
|
||||
|
||||
// // estimateColumnAliasLength estimates the length of a column alias in a nested preload
|
||||
// // Bun creates aliases like: relationChain__columnName
|
||||
// func estimateColumnAliasLength(relationPath string, columnName string) int {
|
||||
// relationParts := strings.Split(relationPath, ".")
|
||||
// aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
// // Bun adds "__" between alias and column name
|
||||
// return len(aliasChain) + 2 + len(columnName)
|
||||
// }
|
||||
|
||||
func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
// Check if this relation chain would create problematic long aliases
|
||||
relationParts := strings.Split(relation, ".")
|
||||
aliasChain := strings.ToLower(strings.Join(relationParts, "__"))
|
||||
|
||||
// PostgreSQL's identifier limit is 63 characters
|
||||
const postgresIdentifierLimit = 63
|
||||
const safeAliasLimit = 35 // Leave room for column names
|
||||
|
||||
// If the alias chain is too long, defer this preload to be executed as a separate query
|
||||
if len(aliasChain) > safeAliasLimit {
|
||||
logger.Info("Preload relation '%s' creates long alias chain '%s' (%d chars). "+
|
||||
"Using separate query to avoid PostgreSQL %d-char identifier limit.",
|
||||
relation, aliasChain, len(aliasChain), postgresIdentifierLimit)
|
||||
|
||||
// For nested preloads (e.g., "Parent.Child"), split into separate preloads
|
||||
// This avoids the long concatenated alias
|
||||
if len(relationParts) > 1 {
|
||||
// Load first level normally: "Parent"
|
||||
firstLevel := relationParts[0]
|
||||
remainingPath := strings.Join(relationParts[1:], ".")
|
||||
|
||||
logger.Info("Splitting nested preload: loading '%s' first, then '%s' separately",
|
||||
firstLevel, remainingPath)
|
||||
|
||||
// Apply the first level preload normally
|
||||
b.query = b.query.Relation(firstLevel)
|
||||
|
||||
// Store the remaining nested preload to be executed after the main query
|
||||
b.deferredPreloads = append(b.deferredPreloads, deferredPreload{
|
||||
relation: relation,
|
||||
apply: apply,
|
||||
})
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Single level but still too long - just warn and continue
|
||||
logger.Warn("Single-level preload '%s' has a very long name (%d chars). "+
|
||||
"Consider renaming the field to avoid potential issues.",
|
||||
relation, len(aliasChain))
|
||||
}
|
||||
|
||||
// Normal preload handling
|
||||
b.query = b.query.Relation(relation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err := logger.HandlePanic("BunSelectQuery.PreloadRelation", r)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
if len(apply) == 0 {
|
||||
return sq
|
||||
}
|
||||
|
||||
// Wrap the incoming *bun.SelectQuery in our adapter
|
||||
wrapper := &BunSelectQuery{
|
||||
query: sq,
|
||||
db: b.db,
|
||||
}
|
||||
|
||||
// Start with the interface value (not pointer)
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
// Apply each function in sequence
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
// Pass ¤t (pointer to interface variable), fn modifies and returns new interface value
|
||||
modified := fn(current)
|
||||
current = modified
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the final *bun.SelectQuery
|
||||
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||
return finalBun.query
|
||||
}
|
||||
|
||||
return sq // fallback
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Order(order string) common.SelectQuery {
|
||||
b.query = b.query.Order(order)
|
||||
return b
|
||||
@@ -226,31 +397,221 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
return b.query.Scan(ctx, dest)
|
||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
if dest == nil {
|
||||
return fmt.Errorf("destination cannot be nil")
|
||||
}
|
||||
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute any deferred preloads
|
||||
if len(b.deferredPreloads) > 0 {
|
||||
err = b.executeDeferredPreloads(ctx, dest)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
count, err := b.query.Count(ctx)
|
||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if b.query.GetModel() == nil {
|
||||
return fmt.Errorf("model is nil")
|
||||
}
|
||||
|
||||
// Execute the main query first
|
||||
err = b.query.Scan(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute any deferred preloads
|
||||
if len(b.deferredPreloads) > 0 {
|
||||
model := b.query.GetModel()
|
||||
err = b.executeDeferredPreloads(ctx, model.Value())
|
||||
if err != nil {
|
||||
logger.Warn("Failed to execute deferred preloads: %v", err)
|
||||
// Don't fail the whole query, just log the warning
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeDeferredPreloads executes preloads that were deferred to avoid PostgreSQL identifier length limits
|
||||
func (b *BunSelectQuery) executeDeferredPreloads(ctx context.Context, dest interface{}) error {
|
||||
if len(b.deferredPreloads) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, dp := range b.deferredPreloads {
|
||||
err := b.executeSingleDeferredPreload(ctx, dest, dp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute deferred preload '%s': %w", dp.relation, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeSingleDeferredPreload executes a single deferred preload
|
||||
// For a relation like "Parent.Child", it:
|
||||
// 1. Finds all loaded Parent records in dest
|
||||
// 2. Loads Child records for those Parents using a separate query (loading only "Child", not "Parent.Child")
|
||||
// 3. Bun automatically assigns the Child records to the appropriate Parent.Child field
|
||||
func (b *BunSelectQuery) executeSingleDeferredPreload(ctx context.Context, dest interface{}, dp deferredPreload) error {
|
||||
relationParts := strings.Split(dp.relation, ".")
|
||||
if len(relationParts) < 2 {
|
||||
return fmt.Errorf("deferred preload must be nested (e.g., 'Parent.Child'), got: %s", dp.relation)
|
||||
}
|
||||
|
||||
// The parent relation that was already loaded
|
||||
parentRelation := relationParts[0]
|
||||
// The child relation we need to load
|
||||
childRelation := strings.Join(relationParts[1:], ".")
|
||||
|
||||
logger.Debug("Executing deferred preload: loading '%s' on already-loaded '%s'", childRelation, parentRelation)
|
||||
|
||||
// Use reflection to access the parent relation field(s) in the loaded records
|
||||
// Then load the child relation for those parent records
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() == reflect.Ptr {
|
||||
destValue = destValue.Elem()
|
||||
}
|
||||
|
||||
// Handle both slice and single record
|
||||
if destValue.Kind() == reflect.Slice {
|
||||
// Iterate through each record in the slice
|
||||
for i := 0; i < destValue.Len(); i++ {
|
||||
record := destValue.Index(i)
|
||||
if err := b.loadChildRelationForRecord(ctx, record, parentRelation, childRelation, dp.apply); err != nil {
|
||||
logger.Warn("Failed to load child relation '%s' for record %d: %v", childRelation, i, err)
|
||||
// Continue with other records
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Single record
|
||||
if err := b.loadChildRelationForRecord(ctx, destValue, parentRelation, childRelation, dp.apply); err != nil {
|
||||
return fmt.Errorf("failed to load child relation '%s': %w", childRelation, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadChildRelationForRecord loads a child relation for a single parent record
|
||||
func (b *BunSelectQuery) loadChildRelationForRecord(ctx context.Context, record reflect.Value, parentRelation, childRelation string, apply []func(common.SelectQuery) common.SelectQuery) error {
|
||||
// Ensure we're working with the actual struct value, not a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Get the parent relation field
|
||||
parentField := record.FieldByName(parentRelation)
|
||||
if !parentField.IsValid() {
|
||||
// Parent relation field doesn't exist
|
||||
logger.Debug("Parent relation field '%s' not found in record", parentRelation)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if the parent field is nil (for pointer fields)
|
||||
if parentField.Kind() == reflect.Ptr && parentField.IsNil() {
|
||||
// Parent relation not loaded or nil, skip
|
||||
logger.Debug("Parent relation field '%s' is nil, skipping child preload", parentRelation)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the interface value to pass to Bun
|
||||
parentValue := parentField.Interface()
|
||||
|
||||
// Load the child relation on the parent record
|
||||
// This uses a shorter alias since we're only loading "Child", not "Parent.Child"
|
||||
return b.db.NewSelect().
|
||||
Model(parentValue).
|
||||
Relation(childRelation, func(sq *bun.SelectQuery) *bun.SelectQuery {
|
||||
// Apply any custom query modifications
|
||||
if len(apply) > 0 {
|
||||
wrapper := &BunSelectQuery{query: sq, db: b.db}
|
||||
current := common.SelectQuery(wrapper)
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
current = fn(current)
|
||||
}
|
||||
}
|
||||
if finalBun, ok := current.(*BunSelectQuery); ok {
|
||||
return finalBun.query
|
||||
}
|
||||
}
|
||||
return sq
|
||||
}).
|
||||
Scan(ctx)
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
// If Model() was set, use bun's native Count() which works properly
|
||||
if b.hasModel {
|
||||
count, err := b.query.Count(ctx)
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||
// This is needed when only Table() is set without a model
|
||||
err = b.db.NewSelect().
|
||||
TableExpr("(?) AS subquery", b.query).
|
||||
ColumnExpr("COUNT(*)").
|
||||
Scan(ctx, &count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
return b.query.Exists(ctx)
|
||||
}
|
||||
|
||||
// BunInsertQuery implements InsertQuery for Bun
|
||||
type BunInsertQuery struct {
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
query *bun.InsertQuery
|
||||
values map[string]interface{}
|
||||
hasModel bool
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.hasModel = true
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
||||
if b.hasModel {
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Table(table)
|
||||
return b
|
||||
}
|
||||
@@ -275,11 +636,22 @@ func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
if b.values != nil {
|
||||
// For Bun, we need to handle this differently
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Set("? = ?", bun.Ident(k), v)
|
||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
if len(b.values) > 0 {
|
||||
if !b.hasModel {
|
||||
// If no model was set, use the values map as the model
|
||||
// Bun can insert map[string]interface{} directly
|
||||
b.query = b.query.Model(&b.values)
|
||||
} else {
|
||||
// If model was set, use Value() to add individual values
|
||||
for k, v := range b.values {
|
||||
b.query = b.query.Value(k, "?", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
result, err := b.query.Exec(ctx)
|
||||
@@ -289,25 +661,50 @@ func (b *BunInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
// BunUpdateQuery implements UpdateQuery for Bun
|
||||
type BunUpdateQuery struct {
|
||||
query *bun.UpdateQuery
|
||||
model interface{}
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
b.query = b.query.Model(model)
|
||||
b.model = model
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
b.query = b.query.Table(table)
|
||||
if b.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
b.model = model
|
||||
}
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||
// Validate column is writable if model is set
|
||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||
// Skip scan-only columns
|
||||
return b
|
||||
}
|
||||
b.query = b.query.Set(column+" = ?", value)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
pkName := reflection.GetPrimaryKeyName(b.model)
|
||||
for column, value := range values {
|
||||
// Validate column is writable if model is set
|
||||
if b.model != nil && !reflection.IsColumnWritable(b.model, column) {
|
||||
// Skip scan-only columns
|
||||
continue
|
||||
}
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
b.query = b.query.Set(column+" = ?", value)
|
||||
}
|
||||
return b
|
||||
@@ -325,7 +722,12 @@ func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
@@ -350,7 +752,12 @@ func (b *BunDeleteQuery) Where(query string, args ...interface{}) common.DeleteQ
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result, err := b.query.Exec(ctx)
|
||||
return &BunResult{result: result}, err
|
||||
}
|
||||
@@ -381,7 +788,10 @@ type BunTxAdapter struct {
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &BunSelectQuery{query: b.tx.NewSelect()}
|
||||
return &BunSelectQuery{
|
||||
query: b.tx.NewSelect(),
|
||||
db: b.tx,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
||||
|
||||
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
213
pkg/common/adapters/database/bun_insert_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||
"github.com/uptrace/bun/driver/sqliteshim"
|
||||
)
|
||||
|
||||
// TestInsertModel is a test model for insert operations
|
||||
type TestInsertModel struct {
|
||||
bun.BaseModel `bun:"table:test_inserts"`
|
||||
ID int64 `bun:"id,pk,autoincrement"`
|
||||
Name string `bun:"name,notnull"`
|
||||
Email string `bun:"email"`
|
||||
Age int `bun:"age"`
|
||||
}
|
||||
|
||||
func setupBunTestDB(t *testing.T) *bun.DB {
|
||||
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||
require.NoError(t, err, "Failed to open SQLite database")
|
||||
|
||||
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||
|
||||
// Create test table
|
||||
_, err = db.NewCreateTable().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
IfNotExists().
|
||||
Exec(context.Background())
|
||||
require.NoError(t, err, "Failed to create test table")
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Model(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Model()
|
||||
model := &TestInsertModel{
|
||||
Name: "John Doe",
|
||||
Email: "john@example.com",
|
||||
Age: 30,
|
||||
}
|
||||
|
||||
result, err := adapter.NewInsert().
|
||||
Model(model).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("id = ?", model.ID).
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "John Doe", retrieved.Name)
|
||||
assert.Equal(t, "john@example.com", retrieved.Email)
|
||||
assert.Equal(t, 30, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Value(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with Value() method - this was the bug
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Jane Smith").
|
||||
Value("email", "jane@example.com").
|
||||
Value("age", 25).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with Value() should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
|
||||
|
||||
// Verify the data was inserted
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Jane Smith").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Jane Smith", retrieved.Name)
|
||||
assert.Equal(t, "jane@example.com", retrieved.Email)
|
||||
assert.Equal(t, 25, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_MultipleValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting multiple values
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Alice").
|
||||
Value("email", "alice@example.com").
|
||||
Value("age", 28).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "First insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
result, err = adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Bob").
|
||||
Value("email", "bob@example.com").
|
||||
Value("age", 35).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Second insert should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify both rows exist
|
||||
var count int
|
||||
count, err = db.NewSelect().
|
||||
Model((*TestInsertModel)(nil)).
|
||||
Count(ctx)
|
||||
|
||||
require.NoError(t, err, "Count should succeed")
|
||||
assert.Equal(t, 2, count, "Should have 2 rows")
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test inserting with nil value for nullable field
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Test User").
|
||||
Value("email", nil). // NULL email
|
||||
Value("age", 20).
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with nil value should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
|
||||
// Verify the data was inserted with NULL email
|
||||
var retrieved TestInsertModel
|
||||
err = db.NewSelect().
|
||||
Model(&retrieved).
|
||||
Where("name = ?", "Test User").
|
||||
Scan(ctx)
|
||||
|
||||
require.NoError(t, err, "Should retrieve inserted row")
|
||||
assert.Equal(t, "Test User", retrieved.Name)
|
||||
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
|
||||
assert.Equal(t, 20, retrieved.Age)
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_Returning(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert with RETURNING clause
|
||||
// Note: SQLite has limited RETURNING support, but this tests the API
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Value("name", "Return Test").
|
||||
Value("email", "return@example.com").
|
||||
Value("age", 40).
|
||||
Returning("*").
|
||||
Exec(ctx)
|
||||
|
||||
require.NoError(t, err, "Insert with RETURNING should succeed")
|
||||
assert.Equal(t, int64(1), result.RowsAffected())
|
||||
}
|
||||
|
||||
func TestBunInsertQuery_EmptyValues(t *testing.T) {
|
||||
db := setupBunTestDB(t)
|
||||
defer db.Close()
|
||||
|
||||
adapter := NewBunAdapter(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Test insert without calling Value() - should use Model() or fail gracefully
|
||||
result, err := adapter.NewInsert().
|
||||
Table("test_inserts").
|
||||
Exec(ctx)
|
||||
|
||||
// This should fail because no values are provided
|
||||
assert.Error(t, err, "Insert without values should fail")
|
||||
if result != nil {
|
||||
assert.Equal(t, int64(0), result.RowsAffected())
|
||||
}
|
||||
}
|
||||
@@ -5,8 +5,12 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
@@ -35,12 +39,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||
return &GormDeleteQuery{db: g.db}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||
}
|
||||
|
||||
@@ -60,7 +74,12 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
|
||||
return g.db.WithContext(ctx).Rollback().Error
|
||||
}
|
||||
|
||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx}
|
||||
return fn(adapter)
|
||||
@@ -70,7 +89,8 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
tableName string
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
}
|
||||
|
||||
@@ -79,7 +99,13 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
|
||||
// Try to get table name from model if it implements TableNameProvider
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
g.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(fullTableName)
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
g.tableAlias = provider.TableAlias()
|
||||
}
|
||||
|
||||
return g
|
||||
@@ -87,7 +113,9 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
|
||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
g.db = g.db.Table(table)
|
||||
g.tableName = table
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(table)
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -96,6 +124,16 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
if len(args) > 0 {
|
||||
g.db = g.db.Select(query, args...)
|
||||
} else {
|
||||
g.db = g.db.Select(query)
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
g.db = g.db.Where(query, args...)
|
||||
return g
|
||||
@@ -120,13 +158,9 @@ func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQ
|
||||
}
|
||||
}
|
||||
|
||||
// If no prefix provided, use the table name as prefix
|
||||
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||
if prefix == "" && g.tableName != "" {
|
||||
prefix = g.tableName
|
||||
// Extract just the table name if it has schema
|
||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||
prefix = prefix[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// If prefix is provided, add it as an alias in the join
|
||||
@@ -161,12 +195,9 @@ func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.Sel
|
||||
}
|
||||
}
|
||||
|
||||
// If no prefix provided, use the table name as prefix
|
||||
// If no prefix provided, use the table name as prefix (already separated from schema)
|
||||
if prefix == "" && g.tableName != "" {
|
||||
prefix = g.tableName
|
||||
if idx := strings.LastIndex(prefix, "."); idx != -1 {
|
||||
prefix = prefix[idx+1:]
|
||||
}
|
||||
}
|
||||
|
||||
// Construct LEFT JOIN with prefix
|
||||
@@ -190,6 +221,36 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
|
||||
if len(apply) == 0 {
|
||||
return db
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
|
||||
for _, fn := range apply {
|
||||
if fn != nil {
|
||||
|
||||
modified := fn(current)
|
||||
current = modified
|
||||
}
|
||||
}
|
||||
|
||||
if finalBun, ok := current.(*GormSelectQuery); ok {
|
||||
return finalBun.db
|
||||
}
|
||||
|
||||
return db // fallback
|
||||
})
|
||||
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Order(order string) common.SelectQuery {
|
||||
g.db = g.db.Order(order)
|
||||
return g
|
||||
@@ -215,19 +276,48 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Find(dest).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
var count int64
|
||||
err := g.db.WithContext(ctx).Count(&count).Error
|
||||
return int(count), err
|
||||
func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
|
||||
}
|
||||
}()
|
||||
if g.db.Statement.Model == nil {
|
||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||
}
|
||||
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Count", r)
|
||||
count = 0
|
||||
}
|
||||
}()
|
||||
var count64 int64
|
||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
||||
return int(count64), err
|
||||
}
|
||||
|
||||
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormSelectQuery.Exists", r)
|
||||
exists = false
|
||||
}
|
||||
}()
|
||||
var count int64
|
||||
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
@@ -267,13 +357,19 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
var result *gorm.DB
|
||||
if g.model != nil {
|
||||
switch {
|
||||
case g.model != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.model)
|
||||
} else if g.values != nil {
|
||||
case g.values != nil:
|
||||
result = g.db.WithContext(ctx).Create(g.values)
|
||||
} else {
|
||||
default:
|
||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||
}
|
||||
return &GormResult{result: result}, result.Error
|
||||
@@ -294,10 +390,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
|
||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
g.db = g.db.Table(table)
|
||||
if g.model == nil {
|
||||
// Try to get table name from table string if model is not set
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
g.model = model
|
||||
}
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
|
||||
// Validate column is writable if model is set
|
||||
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
|
||||
// Skip read-only columns
|
||||
return g
|
||||
}
|
||||
|
||||
if g.updates == nil {
|
||||
g.updates = make(map[string]interface{})
|
||||
}
|
||||
@@ -308,7 +417,25 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
|
||||
g.updates = values
|
||||
|
||||
// Filter out read-only columns if model is set
|
||||
if g.model != nil {
|
||||
pkName := reflection.GetPrimaryKeyName(g.model)
|
||||
filteredValues := make(map[string]interface{})
|
||||
for column, value := range values {
|
||||
if pkName != "" && column == pkName {
|
||||
// Skip primary key updates
|
||||
continue
|
||||
}
|
||||
if reflection.IsColumnWritable(g.model, column) {
|
||||
filteredValues[column] = value
|
||||
}
|
||||
|
||||
}
|
||||
g.updates = filteredValues
|
||||
} else {
|
||||
g.updates = values
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
@@ -322,7 +449,12 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
@@ -349,7 +481,12 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) {
|
||||
func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||
}
|
||||
}()
|
||||
result := g.db.WithContext(ctx).Delete(g.model)
|
||||
return &GormResult{result: result}, result.Error
|
||||
}
|
||||
|
||||
161
pkg/common/adapters/database/update_validation_test.go
Normal file
161
pkg/common/adapters/database/update_validation_test.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Test models for bun
|
||||
type BunTestModel struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
Email string `bun:"email"`
|
||||
ComputedCol string `bun:"computed_col,scanonly"`
|
||||
}
|
||||
|
||||
// Test models for gorm
|
||||
type GormTestModel struct {
|
||||
ID int `gorm:"column:id;primaryKey"`
|
||||
Name string `gorm:"column:name"`
|
||||
Email string `gorm:"column:email"`
|
||||
ReadOnlyCol string `gorm:"column:readonly_col;->"`
|
||||
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
|
||||
}
|
||||
|
||||
func TestIsColumnWritable_Bun(t *testing.T) {
|
||||
model := &BunTestModel{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "writable column - id",
|
||||
columnName: "id",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - name",
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - email",
|
||||
columnName: "email",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "scanonly column should not be writable",
|
||||
columnName: "computed_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent column should be writable (dynamic)",
|
||||
columnName: "nonexistent",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsColumnWritable_Gorm(t *testing.T) {
|
||||
model := &GormTestModel{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "writable column - id",
|
||||
columnName: "id",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - name",
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "writable column - email",
|
||||
columnName: "email",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "read-only column with -> should not be writable",
|
||||
columnName: "readonly_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "column with <-:false should not be writable",
|
||||
columnName: "nowrite_col",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent column should be writable (dynamic)",
|
||||
columnName: "nonexistent",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := reflection.IsColumnWritable(model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
|
||||
// Note: This is a unit test for the validation logic only.
|
||||
// We can't fully test the bun query without a database connection,
|
||||
// but we've verified the validation logic in TestIsColumnWritable_Bun
|
||||
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
|
||||
}
|
||||
|
||||
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
|
||||
model := &GormTestModel{}
|
||||
query := &GormUpdateQuery{
|
||||
model: model,
|
||||
}
|
||||
|
||||
// SetMap should filter out read-only columns
|
||||
values := map[string]interface{}{
|
||||
"name": "John",
|
||||
"email": "john@example.com",
|
||||
"readonly_col": "should_be_filtered",
|
||||
"nowrite_col": "should_also_be_filtered",
|
||||
}
|
||||
|
||||
query.SetMap(values)
|
||||
|
||||
// Check that the updates map only contains writable columns
|
||||
if updates, ok := query.updates.(map[string]interface{}); ok {
|
||||
if _, exists := updates["readonly_col"]; exists {
|
||||
t.Error("readonly_col should have been filtered out")
|
||||
}
|
||||
if _, exists := updates["nowrite_col"]; exists {
|
||||
t.Error("nowrite_col should have been filtered out")
|
||||
}
|
||||
if _, exists := updates["name"]; !exists {
|
||||
t.Error("name should be in updates")
|
||||
}
|
||||
if _, exists := updates["email"]; !exists {
|
||||
t.Error("email should be in updates")
|
||||
}
|
||||
} else {
|
||||
t.Error("updates should be a map[string]interface{}")
|
||||
}
|
||||
}
|
||||
16
pkg/common/adapters/database/utils.go
Normal file
16
pkg/common/adapters/database/utils.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
// For example: "public.users" -> ("public", "users")
|
||||
//
|
||||
// "users" -> ("", "users")
|
||||
func parseTableName(fullTableName string) (schema, table string) {
|
||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||
return fullTableName[:idx], fullTableName[idx+1:]
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
@@ -3,8 +3,10 @@ package router
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"github.com/uptrace/bunrouter"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
|
||||
@@ -34,7 +36,11 @@ func (b *BunRouterAdapter) HandleFunc(pattern string, handler common.HTTPHandler
|
||||
func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
|
||||
// This method would be used when we need to serve through our interface
|
||||
// For now, we'll work directly with the underlying router
|
||||
panic("ServeHTTP not implemented - use GetBunRouter() for direct access")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write. %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetBunRouter returns the underlying bunrouter for direct access
|
||||
@@ -120,6 +126,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
|
||||
return b.req.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (b *BunRouterRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range b.req.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (b *BunRouterRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range b.req.Header {
|
||||
@@ -130,6 +146,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
|
||||
return headers
|
||||
}
|
||||
|
||||
// UnderlyingRequest returns the underlying *http.Request
|
||||
// This is useful when you need to pass the request to other handlers
|
||||
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
|
||||
return b.req.Request
|
||||
}
|
||||
|
||||
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
|
||||
type StandardBunRouterAdapter struct {
|
||||
*BunRouterAdapter
|
||||
@@ -190,4 +212,3 @@ func DefaultBunRouterConfig() *BunRouterConfig {
|
||||
HandleOPTIONS: true,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common"
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// MuxAdapter adapts Gorilla Mux to work with our Router interface
|
||||
@@ -31,7 +33,11 @@ func (m *MuxAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc)
|
||||
func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
|
||||
// This method would be used when we need to serve through our interface
|
||||
// For now, we'll work directly with the underlying router
|
||||
panic("ServeHTTP not implemented - use GetMuxRouter() for direct access")
|
||||
w.WriteHeader(http.StatusNotImplemented)
|
||||
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write. %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// MuxRouteRegistration implements RouteRegistration for Mux
|
||||
@@ -116,6 +122,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
|
||||
return h.req.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (h *HTTPRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range h.req.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (h *HTTPRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range h.req.Header {
|
||||
@@ -126,10 +142,16 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
|
||||
return headers
|
||||
}
|
||||
|
||||
// UnderlyingRequest returns the underlying *http.Request
|
||||
// This is useful when you need to pass the request to other handlers
|
||||
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
|
||||
return h.req
|
||||
}
|
||||
|
||||
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
|
||||
type HTTPResponseWriter struct {
|
||||
resp http.ResponseWriter
|
||||
w common.ResponseWriter
|
||||
w common.ResponseWriter //nolint:unused
|
||||
status int
|
||||
}
|
||||
|
||||
@@ -155,6 +177,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
|
||||
return json.NewEncoder(h.resp).Encode(data)
|
||||
}
|
||||
|
||||
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
|
||||
// This is useful when you need to pass the response writer to other handlers
|
||||
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||
return h.resp
|
||||
}
|
||||
|
||||
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
|
||||
type StandardMuxAdapter struct {
|
||||
*MuxAdapter
|
||||
|
||||
119
pkg/common/cors.go
Normal file
119
pkg/common/cors.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CORSConfig holds CORS configuration
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string
|
||||
AllowedMethods []string
|
||||
AllowedHeaders []string
|
||||
MaxAge int
|
||||
}
|
||||
|
||||
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
|
||||
func DefaultCORSConfig() CORSConfig {
|
||||
return CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
|
||||
AllowedHeaders: GetHeadSpecHeaders(),
|
||||
MaxAge: 86400, // 24 hours
|
||||
}
|
||||
}
|
||||
|
||||
// GetHeadSpecHeaders returns all headers used by HeadSpec
|
||||
func GetHeadSpecHeaders() []string {
|
||||
return []string{
|
||||
// Standard headers
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Accept-Language",
|
||||
"Content-Language",
|
||||
|
||||
// Field Selection
|
||||
"X-Select-Fields",
|
||||
"X-Not-Select-Fields",
|
||||
"X-Clean-JSON",
|
||||
|
||||
// Filtering & Search
|
||||
"X-FieldFilter-*",
|
||||
"X-SearchFilter-*",
|
||||
"X-SearchOp-*",
|
||||
"X-SearchOr-*",
|
||||
"X-SearchAnd-*",
|
||||
"X-SearchCols",
|
||||
"X-Custom-SQL-W",
|
||||
"X-Custom-SQL-W-*",
|
||||
"X-Custom-SQL-Or",
|
||||
"X-Custom-SQL-Or-*",
|
||||
|
||||
// Joins & Relations
|
||||
"X-Preload",
|
||||
"X-Preload-*",
|
||||
"X-Expand",
|
||||
"X-Expand-*",
|
||||
"X-Custom-SQL-Join",
|
||||
"X-Custom-SQL-Join-*",
|
||||
|
||||
// Sorting & Pagination
|
||||
"X-Sort",
|
||||
"X-Sort-*",
|
||||
"X-Limit",
|
||||
"X-Offset",
|
||||
"X-Cursor-Forward",
|
||||
"X-Cursor-Backward",
|
||||
|
||||
// Advanced Features
|
||||
"X-AdvSQL-*",
|
||||
"X-CQL-Sel-*",
|
||||
"X-Distinct",
|
||||
"X-SkipCount",
|
||||
"X-SkipCache",
|
||||
"X-Fetch-RowNumber",
|
||||
"X-PKRow",
|
||||
|
||||
// Response Format
|
||||
"X-SimpleAPI",
|
||||
"X-DetailAPI",
|
||||
"X-Syncfusion",
|
||||
"X-Single-Record-As-Object",
|
||||
|
||||
// Transaction Control
|
||||
"X-Transaction-Atomic",
|
||||
|
||||
// X-Files - comprehensive JSON configuration
|
||||
"X-Files",
|
||||
}
|
||||
}
|
||||
|
||||
// SetCORSHeaders sets CORS headers on a response writer
|
||||
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
// Set allowed origins
|
||||
if len(config.AllowedOrigins) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||
}
|
||||
|
||||
// Set allowed methods
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
|
||||
}
|
||||
|
||||
// Set allowed headers
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
|
||||
// Set max age
|
||||
if config.MaxAge > 0 {
|
||||
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
|
||||
}
|
||||
|
||||
// Allow credentials
|
||||
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// Expose headers that clients can read
|
||||
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
|
||||
}
|
||||
97
pkg/common/handler_example.go
Normal file
97
pkg/common/handler_example.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package common
|
||||
|
||||
// Example showing how to use the common handler interfaces
|
||||
// This file demonstrates the handler interface hierarchy and usage patterns
|
||||
|
||||
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
|
||||
// which works with any handler type (resolvespec, restheadspec, or funcspec)
|
||||
func ProcessWithAnyHandler(handler SpecHandler) Database {
|
||||
// All handlers expose GetDatabase() through the SpecHandler interface
|
||||
return handler.GetDatabase()
|
||||
}
|
||||
|
||||
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
|
||||
// which works with resolvespec.Handler and restheadspec.Handler
|
||||
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
|
||||
// Both resolvespec and restheadspec handlers implement Handle()
|
||||
handler.Handle(w, r, params)
|
||||
}
|
||||
|
||||
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
|
||||
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
|
||||
// Both resolvespec and restheadspec handlers implement HandleGet()
|
||||
handler.HandleGet(w, r, params)
|
||||
}
|
||||
|
||||
// Example usage patterns (not executable, just for documentation):
|
||||
/*
|
||||
// Example 1: Using with resolvespec.Handler
|
||||
func ExampleResolveSpec() {
|
||||
db := // ... get database
|
||||
registry := // ... get registry
|
||||
|
||||
handler := resolvespec.NewHandler(db, registry)
|
||||
|
||||
// Can be used as SpecHandler
|
||||
var specHandler SpecHandler = handler
|
||||
database := specHandler.GetDatabase()
|
||||
|
||||
// Can be used as CRUDHandler
|
||||
var crudHandler CRUDHandler = handler
|
||||
crudHandler.Handle(w, r, params)
|
||||
crudHandler.HandleGet(w, r, params)
|
||||
}
|
||||
|
||||
// Example 2: Using with restheadspec.Handler
|
||||
func ExampleRestHeadSpec() {
|
||||
db := // ... get database
|
||||
registry := // ... get registry
|
||||
|
||||
handler := restheadspec.NewHandler(db, registry)
|
||||
|
||||
// Can be used as SpecHandler
|
||||
var specHandler SpecHandler = handler
|
||||
database := specHandler.GetDatabase()
|
||||
|
||||
// Can be used as CRUDHandler
|
||||
var crudHandler CRUDHandler = handler
|
||||
crudHandler.Handle(w, r, params)
|
||||
crudHandler.HandleGet(w, r, params)
|
||||
}
|
||||
|
||||
// Example 3: Using with funcspec.Handler
|
||||
func ExampleFuncSpec() {
|
||||
db := // ... get database
|
||||
|
||||
handler := funcspec.NewHandler(db)
|
||||
|
||||
// Can be used as SpecHandler
|
||||
var specHandler SpecHandler = handler
|
||||
database := specHandler.GetDatabase()
|
||||
|
||||
// Can be used as QueryHandler
|
||||
var queryHandler QueryHandler = handler
|
||||
// funcspec has different methods: SqlQueryList() and SqlQuery()
|
||||
// which return HTTP handler functions
|
||||
}
|
||||
|
||||
// Example 4: Polymorphic handler processing
|
||||
func ProcessHandlers(handlers []SpecHandler) {
|
||||
for _, handler := range handlers {
|
||||
// All handlers expose the database
|
||||
db := handler.GetDatabase()
|
||||
|
||||
// Type switch for specific handler types
|
||||
switch h := handler.(type) {
|
||||
case CRUDHandler:
|
||||
// This is resolvespec or restheadspec
|
||||
// Can call Handle() and HandleGet()
|
||||
_ = h
|
||||
case QueryHandler:
|
||||
// This is funcspec
|
||||
// Can call SqlQueryList() and SqlQuery()
|
||||
_ = h
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
47
pkg/common/handler_utils.go
Normal file
47
pkg/common/handler_utils.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ValidateAndUnwrapModelResult contains the result of model validation
|
||||
type ValidateAndUnwrapModelResult struct {
|
||||
ModelType reflect.Type
|
||||
Model interface{}
|
||||
ModelPtr interface{}
|
||||
OriginalType reflect.Type
|
||||
}
|
||||
|
||||
// ValidateAndUnwrapModel validates that a model is a struct type and unwraps
|
||||
// pointers, slices, and arrays to get to the base struct type.
|
||||
// Returns an error if the model is not a valid struct type.
|
||||
func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, error) {
|
||||
modelType := reflect.TypeOf(model)
|
||||
originalType := modelType
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType)
|
||||
}
|
||||
|
||||
// If the registered model was a pointer or slice, use the unwrapped struct type
|
||||
if originalType != modelType {
|
||||
model = reflect.New(modelType).Elem().Interface()
|
||||
}
|
||||
|
||||
// Create a pointer to the model type for database operations
|
||||
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
|
||||
|
||||
return &ValidateAndUnwrapModelResult{
|
||||
ModelType: modelType,
|
||||
Model: model,
|
||||
ModelPtr: modelPtr,
|
||||
OriginalType: originalType,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
package common
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Database interface designed to work with both GORM and Bun
|
||||
type Database interface {
|
||||
@@ -26,11 +31,13 @@ type SelectQuery interface {
|
||||
Model(model interface{}) SelectQuery
|
||||
Table(table string) SelectQuery
|
||||
Column(columns ...string) SelectQuery
|
||||
ColumnExpr(query string, args ...interface{}) SelectQuery
|
||||
Where(query string, args ...interface{}) SelectQuery
|
||||
WhereOr(query string, args ...interface{}) SelectQuery
|
||||
Join(query string, args ...interface{}) SelectQuery
|
||||
LeftJoin(query string, args ...interface{}) SelectQuery
|
||||
Preload(relation string, conditions ...interface{}) SelectQuery
|
||||
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
|
||||
Order(order string) SelectQuery
|
||||
Limit(n int) SelectQuery
|
||||
Offset(n int) SelectQuery
|
||||
@@ -39,6 +46,7 @@ type SelectQuery interface {
|
||||
|
||||
// Execution methods
|
||||
Scan(ctx context.Context, dest interface{}) error
|
||||
ScanModel(ctx context.Context) error
|
||||
Count(ctx context.Context) (int, error)
|
||||
Exists(ctx context.Context) (bool, error)
|
||||
}
|
||||
@@ -113,6 +121,8 @@ type Request interface {
|
||||
Body() ([]byte, error)
|
||||
PathParam(key string) string
|
||||
QueryParam(key string) string
|
||||
AllQueryParams() map[string]string // Get all query parameters as a map
|
||||
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
|
||||
}
|
||||
|
||||
// ResponseWriter interface abstracts HTTP response
|
||||
@@ -121,17 +131,164 @@ type ResponseWriter interface {
|
||||
WriteHeader(statusCode int)
|
||||
Write(data []byte) (int, error)
|
||||
WriteJSON(data interface{}) error
|
||||
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
|
||||
}
|
||||
|
||||
// HTTPHandlerFunc type for HTTP handlers
|
||||
type HTTPHandlerFunc func(ResponseWriter, Request)
|
||||
|
||||
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
|
||||
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
|
||||
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
|
||||
}
|
||||
|
||||
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
|
||||
type StandardResponseWriter struct {
|
||||
w http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) SetHeader(key, value string) {
|
||||
s.w.Header().Set(key, value)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
|
||||
s.status = statusCode
|
||||
s.w.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.w.Write(data)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
|
||||
s.SetHeader("Content-Type", "application/json")
|
||||
return json.NewEncoder(s.w).Encode(data)
|
||||
}
|
||||
|
||||
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
|
||||
return s.w
|
||||
}
|
||||
|
||||
// StandardRequest adapts *http.Request to Request interface
|
||||
type StandardRequest struct {
|
||||
r *http.Request
|
||||
body []byte
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Method() string {
|
||||
return s.r.Method
|
||||
}
|
||||
|
||||
func (s *StandardRequest) URL() string {
|
||||
return s.r.URL.String()
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Header(key string) string {
|
||||
return s.r.Header.Get(key)
|
||||
}
|
||||
|
||||
func (s *StandardRequest) AllHeaders() map[string]string {
|
||||
headers := make(map[string]string)
|
||||
for key, values := range s.r.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func (s *StandardRequest) Body() ([]byte, error) {
|
||||
if s.body != nil {
|
||||
return s.body, nil
|
||||
}
|
||||
if s.r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer s.r.Body.Close()
|
||||
body, err := io.ReadAll(s.r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.body = body
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func (s *StandardRequest) PathParam(key string) string {
|
||||
// Standard http.Request doesn't have path params
|
||||
// This should be set by the router
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *StandardRequest) QueryParam(key string) string {
|
||||
return s.r.URL.Query().Get(key)
|
||||
}
|
||||
|
||||
func (s *StandardRequest) AllQueryParams() map[string]string {
|
||||
params := make(map[string]string)
|
||||
for key, values := range s.r.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
params[key] = values[0]
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
func (s *StandardRequest) UnderlyingRequest() *http.Request {
|
||||
return s.r
|
||||
}
|
||||
|
||||
// TableNameProvider interface for models that provide table names
|
||||
type TableNameProvider interface {
|
||||
TableName() string
|
||||
}
|
||||
|
||||
type TableAliasProvider interface {
|
||||
TableAlias() string
|
||||
}
|
||||
|
||||
// PrimaryKeyNameProvider interface for models that provide primary key column names
|
||||
type PrimaryKeyNameProvider interface {
|
||||
GetIDName() string
|
||||
}
|
||||
|
||||
// SchemaProvider interface for models that provide schema names
|
||||
type SchemaProvider interface {
|
||||
SchemaName() string
|
||||
}
|
||||
|
||||
// SpecHandler interface represents common functionality across all spec handlers
|
||||
// This is the base interface implemented by:
|
||||
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
|
||||
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
|
||||
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
|
||||
//
|
||||
// The interface hierarchy is:
|
||||
//
|
||||
// SpecHandler (base)
|
||||
// ├── CRUDHandler (resolvespec, restheadspec)
|
||||
// └── QueryHandler (funcspec)
|
||||
type SpecHandler interface {
|
||||
// GetDatabase returns the underlying database connection
|
||||
GetDatabase() Database
|
||||
}
|
||||
|
||||
// CRUDHandler interface for handlers that support CRUD operations
|
||||
// This is implemented by resolvespec.Handler and restheadspec.Handler
|
||||
type CRUDHandler interface {
|
||||
SpecHandler
|
||||
|
||||
// Handle processes API requests through router-agnostic interface
|
||||
Handle(w ResponseWriter, r Request, params map[string]string)
|
||||
|
||||
// HandleGet processes GET requests for metadata
|
||||
HandleGet(w ResponseWriter, r Request, params map[string]string)
|
||||
}
|
||||
|
||||
// QueryHandler interface for handlers that execute SQL queries
|
||||
// This is implemented by funcspec.Handler
|
||||
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
|
||||
type QueryHandler interface {
|
||||
SpecHandler
|
||||
// Methods are defined in funcspec package due to different function signature requirements
|
||||
}
|
||||
|
||||
453
pkg/common/recursive_crud.go
Normal file
453
pkg/common/recursive_crud.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// CRUDRequestProvider interface for models that provide CRUD request strings
|
||||
type CRUDRequestProvider interface {
|
||||
GetRequest() string
|
||||
}
|
||||
|
||||
// RelationshipInfoProvider interface for handlers that can provide relationship info
|
||||
type RelationshipInfoProvider interface {
|
||||
GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo
|
||||
}
|
||||
|
||||
// RelationshipInfo contains information about a model relationship
|
||||
type RelationshipInfo struct {
|
||||
FieldName string
|
||||
JSONName string
|
||||
RelationType string // "belongsTo", "hasMany", "hasOne", "many2many"
|
||||
ForeignKey string
|
||||
References string
|
||||
JoinTable string
|
||||
RelatedModel interface{}
|
||||
}
|
||||
|
||||
// NestedCUDProcessor handles recursive processing of nested object graphs
|
||||
type NestedCUDProcessor struct {
|
||||
db Database
|
||||
registry ModelRegistry
|
||||
relationshipHelper RelationshipInfoProvider
|
||||
}
|
||||
|
||||
// NewNestedCUDProcessor creates a new nested CUD processor
|
||||
func NewNestedCUDProcessor(db Database, registry ModelRegistry, relationshipHelper RelationshipInfoProvider) *NestedCUDProcessor {
|
||||
return &NestedCUDProcessor{
|
||||
db: db,
|
||||
registry: registry,
|
||||
relationshipHelper: relationshipHelper,
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessResult contains the result of processing a CUD operation
|
||||
type ProcessResult struct {
|
||||
ID interface{} // The ID of the processed record
|
||||
AffectedRows int64 // Number of rows affected
|
||||
Data map[string]interface{} // The processed data
|
||||
RelationData map[string]interface{} // Data from processed relations
|
||||
}
|
||||
|
||||
// ProcessNestedCUD recursively processes nested object graphs for Create, Update, Delete operations
|
||||
// with automatic foreign key resolution
|
||||
func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
ctx context.Context,
|
||||
operation string, // "insert", "update", or "delete"
|
||||
data map[string]interface{},
|
||||
model interface{},
|
||||
parentIDs map[string]interface{}, // Parent IDs for foreign key resolution
|
||||
tableName string,
|
||||
) (*ProcessResult, error) {
|
||||
logger.Info("Processing nested CUD: operation=%s, table=%s", operation, tableName)
|
||||
|
||||
result := &ProcessResult{
|
||||
Data: make(map[string]interface{}),
|
||||
RelationData: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Check if data has a _request field that overrides the operation
|
||||
if requestOp := p.extractCRUDRequest(data); requestOp != "" {
|
||||
logger.Debug("Found _request override: %s", requestOp)
|
||||
operation = requestOp
|
||||
}
|
||||
|
||||
// Get model type for reflection
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||
}
|
||||
|
||||
// Separate relation fields from regular fields
|
||||
relationFields := make(map[string]*RelationshipInfo)
|
||||
regularData := make(map[string]interface{})
|
||||
|
||||
for key, value := range data {
|
||||
// Skip _request field in actual data processing
|
||||
if key == "_request" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field is a relation
|
||||
relInfo := p.relationshipHelper.GetRelationshipInfo(modelType, key)
|
||||
if relInfo != nil {
|
||||
relationFields[key] = relInfo
|
||||
result.RelationData[key] = value
|
||||
} else {
|
||||
regularData[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
// Get the primary key name for this model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
result.ID = id
|
||||
result.AffectedRows = 1
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations after parent insert (to get parent ID)
|
||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
|
||||
case "update":
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
|
||||
case "delete":
|
||||
// Process child relations first (for referential integrity)
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
|
||||
logger.Info("Nested CUD completed: operation=%s, id=%v, rows=%d", operation, result.ID, result.AffectedRows)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// extractCRUDRequest extracts the request field from data if present
|
||||
func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) string {
|
||||
if request, ok := data["_request"]; ok {
|
||||
if requestStr, ok := request.(string); ok {
|
||||
return strings.ToLower(strings.TrimSpace(requestStr))
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||
if len(parentIDs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through model fields to find foreign key fields
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
|
||||
// Check if this field is a foreign key and we have a parent ID for it
|
||||
// Common patterns: DepartmentID, ManagerID, ProjectID, etc.
|
||||
for parentKey, parentID := range parentIDs {
|
||||
// Match field name patterns like "department_id" with parent key "department"
|
||||
if strings.EqualFold(jsonName, parentKey+"_id") ||
|
||||
strings.EqualFold(jsonName, parentKey+"id") ||
|
||||
strings.EqualFold(field.Name, parentKey+"ID") {
|
||||
// Only inject if not already present
|
||||
if _, exists := data[jsonName]; !exists {
|
||||
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID)
|
||||
data[jsonName] = parentID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processInsert handles insert operation
|
||||
func (p *NestedCUDProcessor) processInsert(
|
||||
ctx context.Context,
|
||||
data map[string]interface{},
|
||||
tableName string,
|
||||
) (interface{}, error) {
|
||||
logger.Debug("Inserting into %s with data: %+v", tableName, data)
|
||||
|
||||
query := p.db.NewInsert().Table(tableName)
|
||||
|
||||
for key, value := range data {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
|
||||
// Add RETURNING clause to get the inserted ID
|
||||
query = query.Returning("id")
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert exec failed: %w", err)
|
||||
}
|
||||
|
||||
// Try to get the ID
|
||||
var id interface{}
|
||||
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
|
||||
id = lastID
|
||||
} else if data["id"] != nil {
|
||||
id = data["id"]
|
||||
}
|
||||
|
||||
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// processUpdate handles update operation
|
||||
func (p *NestedCUDProcessor) processUpdate(
|
||||
ctx context.Context,
|
||||
data map[string]interface{},
|
||||
tableName string,
|
||||
id interface{},
|
||||
) (int64, error) {
|
||||
if id == nil {
|
||||
return 0, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data)
|
||||
|
||||
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("update exec failed: %w", err)
|
||||
}
|
||||
|
||||
rows := result.RowsAffected()
|
||||
logger.Debug("Update successful, rows affected: %d", rows)
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// processDelete handles delete operation
|
||||
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
|
||||
if id == nil {
|
||||
return 0, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
logger.Debug("Deleting from %s with ID %v", tableName, id)
|
||||
|
||||
query := p.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("delete exec failed: %w", err)
|
||||
}
|
||||
|
||||
rows := result.RowsAffected()
|
||||
logger.Debug("Delete successful, rows affected: %d", rows)
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// processChildRelations recursively processes child relations
|
||||
func (p *NestedCUDProcessor) processChildRelations(
|
||||
ctx context.Context,
|
||||
operation string,
|
||||
parentID interface{},
|
||||
relationFields map[string]*RelationshipInfo,
|
||||
relationData map[string]interface{},
|
||||
parentModelType reflect.Type,
|
||||
) error {
|
||||
for relationName, relInfo := range relationFields {
|
||||
relationValue, exists := relationData[relationName]
|
||||
if !exists || relationValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debug("Processing relation: %s, type: %s", relationName, relInfo.RelationType)
|
||||
|
||||
// Get the related model
|
||||
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||
if !found {
|
||||
logger.Warn("Field %s not found in model", relInfo.FieldName)
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the model type for the relation
|
||||
relatedModelType := field.Type
|
||||
if relatedModelType.Kind() == reflect.Slice {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
if relatedModelType.Kind() == reflect.Ptr {
|
||||
relatedModelType = relatedModelType.Elem()
|
||||
}
|
||||
|
||||
// Create an instance of the related model
|
||||
relatedModel := reflect.New(relatedModelType).Elem().Interface()
|
||||
|
||||
// Get table name for related model
|
||||
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
|
||||
|
||||
// Prepare parent IDs for foreign key injection
|
||||
parentIDs := make(map[string]interface{})
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
|
||||
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||
parentIDs[baseName] = parentID
|
||||
}
|
||||
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||
}
|
||||
|
||||
case []interface{}:
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTableNameForModel gets the table name for a model
|
||||
func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName string) string {
|
||||
if provider, ok := model.(TableNameProvider); ok {
|
||||
tableName := provider.TableName()
|
||||
if tableName != "" {
|
||||
return tableName
|
||||
}
|
||||
}
|
||||
return defaultName
|
||||
}
|
||||
|
||||
// ShouldUseNestedProcessor determines if we should use nested CUD processing
|
||||
// It recursively checks if the data contains:
|
||||
// 1. A _request field at any level, OR
|
||||
// 2. Nested relations that themselves contain further nested relations or _request fields
|
||||
// This ensures nested processing is only used when there are deeply nested operations
|
||||
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
|
||||
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
|
||||
}
|
||||
|
||||
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
|
||||
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
|
||||
// Check for _request field
|
||||
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
|
||||
return true
|
||||
}
|
||||
|
||||
// Get model type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if data contains any fields that are relations (nested objects or arrays)
|
||||
for key, value := range data {
|
||||
// Skip _request and regular scalar fields
|
||||
if key == "_request" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field is a relation in the model
|
||||
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
|
||||
if relInfo != nil {
|
||||
// Check if the value is actually nested data (object or array)
|
||||
switch v := value.(type) {
|
||||
case map[string]interface{}, []interface{}, []map[string]interface{}:
|
||||
// If we're already at a nested level (depth > 0) and found a relation,
|
||||
// that means we have multi-level nesting, so return true
|
||||
if depth > 0 {
|
||||
return true
|
||||
}
|
||||
// At depth 0, recurse to check if the nested data has further nesting
|
||||
switch typedValue := v.(type) {
|
||||
case map[string]interface{}:
|
||||
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
case []interface{}:
|
||||
for _, item := range typedValue {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
case []map[string]interface{}:
|
||||
for _, itemMap := range typedValue {
|
||||
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
340
pkg/common/sql_helpers.go
Normal file
340
pkg/common/sql_helpers.go
Normal file
@@ -0,0 +1,340 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ValidateAndFixPreloadWhere validates that the WHERE clause for a preload contains
|
||||
// the relation prefix (alias). If not present, it attempts to add it to column references.
|
||||
// Returns the fixed WHERE clause and an error if it cannot be safely fixed.
|
||||
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
|
||||
if where == "" {
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// Check if the relation name is already present in the WHERE clause
|
||||
lowerWhere := strings.ToLower(where)
|
||||
lowerRelation := strings.ToLower(relationName)
|
||||
|
||||
// Check for patterns like "relation.", "relation ", or just "relation" followed by a dot
|
||||
if strings.Contains(lowerWhere, lowerRelation+".") ||
|
||||
strings.Contains(lowerWhere, "`"+lowerRelation+"`.") ||
|
||||
strings.Contains(lowerWhere, "\""+lowerRelation+"\".") {
|
||||
// Relation prefix is already present
|
||||
return where, nil
|
||||
}
|
||||
|
||||
// If the WHERE clause is complex (contains OR, parentheses, subqueries, etc.),
|
||||
// we can't safely auto-fix it - require explicit prefix
|
||||
if strings.Contains(lowerWhere, " or ") ||
|
||||
strings.Contains(where, "(") ||
|
||||
strings.Contains(where, ")") {
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Complex WHERE clauses with OR/parentheses must explicitly use the relation prefix", relationName, relationName)
|
||||
}
|
||||
|
||||
// Try to add the relation prefix to simple column references
|
||||
// This handles basic cases like "column = value" or "column = value AND other_column = value"
|
||||
// Split by AND to handle multiple conditions (case-insensitive)
|
||||
originalConditions := strings.Split(where, " AND ")
|
||||
|
||||
// If uppercase split didn't work, try lowercase
|
||||
if len(originalConditions) == 1 {
|
||||
originalConditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
fixedConditions := make([]string, 0, len(originalConditions))
|
||||
|
||||
for _, cond := range originalConditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this condition already has a table prefix (contains a dot)
|
||||
if strings.Contains(cond, ".") {
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
lowerCond := strings.ToLower(strings.TrimSpace(cond))
|
||||
if IsSQLExpression(lowerCond) {
|
||||
// Don't prefix SQL expressions like "true", "false", "1=1", etc.
|
||||
fixedConditions = append(fixedConditions, cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract the column name (first identifier before operator)
|
||||
columnName := ExtractColumnName(cond)
|
||||
if columnName == "" {
|
||||
// Can't identify column name, require explicit prefix
|
||||
return "", fmt.Errorf("preload WHERE condition must reference the relation '%s' (e.g., '%s.column_name'). Cannot auto-fix condition: %s", relationName, relationName, cond)
|
||||
}
|
||||
|
||||
// Add relation prefix to the column name only
|
||||
fixedCond := strings.Replace(cond, columnName, relationName+"."+columnName, 1)
|
||||
fixedConditions = append(fixedConditions, fixedCond)
|
||||
}
|
||||
|
||||
fixedWhere := strings.Join(fixedConditions, " AND ")
|
||||
logger.Debug("Auto-fixed preload WHERE clause: '%s' -> '%s'", where, fixedWhere)
|
||||
return fixedWhere, nil
|
||||
}
|
||||
|
||||
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
|
||||
func IsSQLExpression(cond string) bool {
|
||||
// Common SQL literals and expressions
|
||||
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
|
||||
for _, literal := range sqlLiterals {
|
||||
if cond == literal {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
||||
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
||||
func IsTrivialCondition(cond string) bool {
|
||||
cond = strings.TrimSpace(cond)
|
||||
lowerCond := strings.ToLower(cond)
|
||||
|
||||
// Conditions that always evaluate to true
|
||||
trivialConditions := []string{
|
||||
"1=1", "1 = 1", "1= 1", "1 =1",
|
||||
"true", "true = true", "true=true", "true= true", "true =true",
|
||||
"0=0", "0 = 0", "0= 0", "0 =0",
|
||||
}
|
||||
|
||||
for _, trivial := range trivialConditions {
|
||||
if lowerCond == trivial {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeWhereClause removes trivial conditions and optionally prefixes table/relation names to columns
|
||||
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
|
||||
//
|
||||
// Parameters:
|
||||
// - where: The WHERE clause string to sanitize
|
||||
// - tableName: Optional table/relation name to prefix to column references (empty string to skip prefixing)
|
||||
//
|
||||
// Returns:
|
||||
// - The sanitized WHERE clause with trivial conditions removed and columns optionally prefixed
|
||||
// - An empty string if all conditions were trivial or the input was empty
|
||||
func SanitizeWhereClause(where string, tableName string) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
where = strings.TrimSpace(where)
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
|
||||
// Get valid columns from the model if tableName is provided
|
||||
var validColumns map[string]bool
|
||||
if tableName != "" {
|
||||
validColumns = getValidColumnsForTable(tableName)
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
conditions := splitByAND(where)
|
||||
|
||||
validConditions := make([]string, 0, len(conditions))
|
||||
|
||||
for _, cond := range conditions {
|
||||
cond = strings.TrimSpace(cond)
|
||||
if cond == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip parentheses from the condition before checking
|
||||
condToCheck := stripOuterParentheses(cond)
|
||||
|
||||
// Skip trivial conditions that always evaluate to true
|
||||
if IsTrivialCondition(condToCheck) {
|
||||
logger.Debug("Removing trivial condition: '%s'", cond)
|
||||
continue
|
||||
}
|
||||
|
||||
// If tableName is provided and the condition doesn't already have a table prefix,
|
||||
// attempt to add it
|
||||
if tableName != "" && !hasTablePrefix(condToCheck) {
|
||||
// Check if this is a SQL expression/literal that shouldn't be prefixed
|
||||
if !IsSQLExpression(strings.ToLower(condToCheck)) {
|
||||
// Extract the column name and prefix it
|
||||
columnName := ExtractColumnName(condToCheck)
|
||||
if columnName != "" {
|
||||
// Only prefix if this is a valid column in the model
|
||||
// If we don't have model info (validColumns is nil), prefix anyway for backward compatibility
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace in the original condition (without stripped parens)
|
||||
cond = strings.Replace(cond, columnName, tableName+"."+columnName, 1)
|
||||
logger.Debug("Prefixed column in condition: '%s'", cond)
|
||||
} else {
|
||||
logger.Debug("Skipping prefix for '%s' - not a valid column in model", columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validConditions = append(validConditions, cond)
|
||||
}
|
||||
|
||||
if len(validConditions) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
result := strings.Join(validConditions, " AND ")
|
||||
|
||||
if result != where {
|
||||
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// stripOuterParentheses removes matching outer parentheses from a string
|
||||
// It handles nested parentheses correctly
|
||||
func stripOuterParentheses(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
for {
|
||||
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
|
||||
return s
|
||||
}
|
||||
|
||||
// Check if these parentheses match (i.e., they're the outermost pair)
|
||||
depth := 0
|
||||
matched := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
if depth == 0 && i == len(s)-1 {
|
||||
matched = true
|
||||
} else if depth == 0 {
|
||||
// Found a closing paren before the end, so outer parens don't match
|
||||
return s
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !matched {
|
||||
return s
|
||||
}
|
||||
|
||||
// Strip the outer parentheses and continue
|
||||
s = strings.TrimSpace(s[1 : len(s)-1])
|
||||
}
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is a simple split that doesn't handle nested parentheses or complex expressions
|
||||
func splitByAND(where string) []string {
|
||||
// First try uppercase AND
|
||||
conditions := strings.Split(where, " AND ")
|
||||
|
||||
// If we didn't split on uppercase, try lowercase
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " and ")
|
||||
}
|
||||
|
||||
// If we still didn't split, try mixed case
|
||||
if len(conditions) == 1 {
|
||||
conditions = strings.Split(where, " And ")
|
||||
}
|
||||
|
||||
return conditions
|
||||
}
|
||||
|
||||
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
|
||||
func hasTablePrefix(cond string) bool {
|
||||
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
|
||||
return strings.Contains(cond, ".")
|
||||
}
|
||||
|
||||
// ExtractColumnName extracts the column name from a WHERE condition
|
||||
// For example: "status = 'active'" returns "status"
|
||||
func ExtractColumnName(cond string) string {
|
||||
// Common SQL operators
|
||||
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
|
||||
|
||||
for _, op := range operators {
|
||||
if idx := strings.Index(cond, op); idx > 0 {
|
||||
columnName := strings.TrimSpace(cond[:idx])
|
||||
// Remove quotes if present
|
||||
columnName = strings.Trim(columnName, "`\"'")
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
// If no operator found, check if it's a simple identifier (for boolean columns)
|
||||
parts := strings.Fields(cond)
|
||||
if len(parts) > 0 {
|
||||
columnName := strings.Trim(parts[0], "`\"'")
|
||||
// Check if it's a valid identifier (not a SQL keyword)
|
||||
if !IsSQLKeyword(strings.ToLower(columnName)) {
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
|
||||
func IsSQLKeyword(word string) bool {
|
||||
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
|
||||
for _, kw := range keywords {
|
||||
if word == kw {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
|
||||
// Returns a map of column names for fast lookup, or nil if the model is not found
|
||||
func getValidColumnsForTable(tableName string) map[string]bool {
|
||||
// Try to get the model from the registry
|
||||
model, err := modelregistry.GetModelByName(tableName)
|
||||
if err != nil {
|
||||
// Model not found, return nil to indicate we should use fallback behavior
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get SQL columns from the model
|
||||
columns := reflection.GetSQLModelColumns(model)
|
||||
if len(columns) == 0 {
|
||||
// No columns found, return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build a map for fast lookup
|
||||
columnMap := make(map[string]bool, len(columns))
|
||||
for _, col := range columns {
|
||||
columnMap[strings.ToLower(col)] = true
|
||||
}
|
||||
|
||||
return columnMap
|
||||
}
|
||||
|
||||
// isValidColumn checks if a column name exists in the valid columns map
|
||||
// Handles case-insensitive comparison
|
||||
func isValidColumn(columnName string, validColumns map[string]bool) bool {
|
||||
if validColumns == nil {
|
||||
return true // No model info, assume valid
|
||||
}
|
||||
return validColumns[strings.ToLower(columnName)]
|
||||
}
|
||||
224
pkg/common/sql_helpers_test.go
Normal file
224
pkg/common/sql_helpers_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
func TestSanitizeWhereClause(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "trivial conditions in parentheses",
|
||||
where: "(true AND true AND true)",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "trivial conditions without parentheses",
|
||||
where: "true AND true AND true",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single trivial condition",
|
||||
where: "true",
|
||||
tableName: "mastertask",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "valid condition with parentheses",
|
||||
where: "(status = 'active')",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "mixed trivial and valid conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "condition already with table prefix",
|
||||
where: "users.status = 'active'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid conditions",
|
||||
where: "status = 'active' AND age > 18",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "no table name provided",
|
||||
where: "status = 'active'",
|
||||
tableName: "",
|
||||
expected: "status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "empty where clause",
|
||||
where: "",
|
||||
tableName: "users",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripOuterParentheses(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single level parentheses",
|
||||
input: "(true)",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "multiple levels",
|
||||
input: "((true))",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "no parentheses",
|
||||
input: "true",
|
||||
expected: "true",
|
||||
},
|
||||
{
|
||||
name: "mismatched parentheses",
|
||||
input: "(true",
|
||||
expected: "(true",
|
||||
},
|
||||
{
|
||||
name: "complex expression",
|
||||
input: "(a AND b)",
|
||||
expected: "a AND b",
|
||||
},
|
||||
{
|
||||
name: "nested but not outer",
|
||||
input: "(a AND (b OR c)) AND d",
|
||||
expected: "(a AND (b OR c)) AND d",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
input: " ( true ) ",
|
||||
expected: "true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := stripOuterParentheses(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTrivialCondition(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"true", "true", true},
|
||||
{"true with spaces", " true ", true},
|
||||
{"TRUE uppercase", "TRUE", true},
|
||||
{"1=1", "1=1", true},
|
||||
{"1 = 1", "1 = 1", true},
|
||||
{"true = true", "true = true", true},
|
||||
{"valid condition", "status = 'active'", false},
|
||||
{"false", "false", false},
|
||||
{"column name", "is_active", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsTrivialCondition(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test model for model-aware sanitization tests
|
||||
type MasterTask struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
Status string `bun:"status"`
|
||||
UserID int `bun:"user_id"`
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
// Register the test model
|
||||
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
|
||||
if err != nil {
|
||||
// Model might already be registered, ignore error
|
||||
t.Logf("Model registration returned: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "valid column gets prefixed",
|
||||
where: "status = 'active'",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "multiple valid columns get prefixed",
|
||||
where: "status = 'active' AND user_id = 123",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
|
||||
},
|
||||
{
|
||||
name: "invalid column does not get prefixed",
|
||||
where: "invalid_column = 'value'",
|
||||
tableName: "mastertask",
|
||||
expected: "invalid_column = 'value'",
|
||||
},
|
||||
{
|
||||
name: "mix of valid and trivial conditions",
|
||||
where: "true AND status = 'active' AND 1=1",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
{
|
||||
name: "parentheses with valid column",
|
||||
where: "(status = 'active')",
|
||||
tableName: "mastertask",
|
||||
expected: "mastertask.status = 'active'",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
771
pkg/common/sql_types.go
Normal file
771
pkg/common/sql_types.go
Normal file
@@ -0,0 +1,771 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func tryParseDT(str string) (time.Time, error) {
|
||||
var lasterror error
|
||||
tryFormats := []string{time.RFC3339,
|
||||
"2006-01-02T15:04:05.000-0700",
|
||||
"2006-01-02T15:04:05.000",
|
||||
"06-01-02T15:04:05.000",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04:05",
|
||||
"02/01/2006",
|
||||
"02-01-2006",
|
||||
"2006-01-02",
|
||||
"15:04:05.000",
|
||||
"15:04:05",
|
||||
"15:04"}
|
||||
|
||||
for _, f := range tryFormats {
|
||||
tx, err := time.Parse(f, str)
|
||||
if err == nil {
|
||||
return tx, nil
|
||||
} else {
|
||||
lasterror = err
|
||||
}
|
||||
}
|
||||
|
||||
return time.Now(), lasterror
|
||||
}
|
||||
|
||||
func ToJSONDT(dt time.Time) string {
|
||||
return dt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// SqlInt16 - A Int16 that supports SQL string
|
||||
type SqlInt16 int16
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt16) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt16(v)
|
||||
case int32:
|
||||
*n = SqlInt16(v)
|
||||
case int64:
|
||||
*n = SqlInt16(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt16(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt16) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt16) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt16) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt16(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt16) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlInt32 - A int32 that supports SQL string
|
||||
type SqlInt32 int32
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt32) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt32(v)
|
||||
case int32:
|
||||
*n = SqlInt32(v)
|
||||
case int64:
|
||||
*n = SqlInt32(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt32(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt32) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt32) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt32) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt32(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt32) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlInt64 - A int64 that supports SQL string
|
||||
type SqlInt64 int64
|
||||
|
||||
// Scan -
|
||||
func (n *SqlInt64) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = 0
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
*n = SqlInt64(v)
|
||||
case int32:
|
||||
*n = SqlInt64(v)
|
||||
case uint32:
|
||||
*n = SqlInt64(v)
|
||||
case int64:
|
||||
*n = SqlInt64(v)
|
||||
case uint64:
|
||||
*n = SqlInt64(v)
|
||||
default:
|
||||
i, _ := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
*n = SqlInt64(i)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlInt64) Value() (driver.Value, error) {
|
||||
if n == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(n), nil
|
||||
}
|
||||
|
||||
// String - Override String format of ZNullInt32
|
||||
func (n SqlInt64) String() string {
|
||||
tmstr := fmt.Sprintf("%d", n)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Overre JidSON format of ZNullInt32
|
||||
func (n *SqlInt64) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
|
||||
n64, err := strconv.ParseInt(s, 10, 64)
|
||||
if err == nil {
|
||||
*n = SqlInt64(n64)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlInt64) MarshalJSON() ([]byte, error) {
|
||||
return []byte(fmt.Sprintf("%d", n)), nil
|
||||
}
|
||||
|
||||
// SqlTimeStamp - Implementation of SqlTimeStamp with some interfaces.
|
||||
type SqlTimeStamp time.Time
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
|
||||
if time.Time(t).IsZero() {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
if time.Time(t).Before(time.Date(0001, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||
if tmstr == "0001-01-01T00:00:00" {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
|
||||
if b == nil {
|
||||
|
||||
return nil
|
||||
}
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
s == "0001-01-01T00:00:00" || s == "0001-01-01" {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = SqlTimeStamp(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlTimeStamp) Value() (driver.Value, error) {
|
||||
if t.GetTime().IsZero() || t.GetTime().Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
|
||||
return nil, nil
|
||||
}
|
||||
tmstr := time.Time(t).Format("2006-01-02T15:04:05")
|
||||
if tmstr <= "0001-01-01" || tmstr == "" {
|
||||
empty := time.Time{}
|
||||
return empty, nil
|
||||
}
|
||||
|
||||
return tmstr, nil
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlTimeStamp) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlTimeStamp(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*t = SqlTimeStamp(tx)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlTimeStamp) String() string {
|
||||
return time.Time(t).Format("2006-01-02T15:04:05")
|
||||
}
|
||||
|
||||
// GetTime - Returns Time
|
||||
func (t SqlTimeStamp) GetTime() time.Time {
|
||||
return time.Time(t)
|
||||
}
|
||||
|
||||
// SetTime - Returns Time
|
||||
func (t *SqlTimeStamp) SetTime(pTime time.Time) {
|
||||
*t = SqlTimeStamp(pTime)
|
||||
}
|
||||
|
||||
// Format - Formats the time
|
||||
func (t SqlTimeStamp) Format(layout string) string {
|
||||
return time.Time(t).Format(layout)
|
||||
}
|
||||
|
||||
func SqlTimeStampNow() SqlTimeStamp {
|
||||
tx := time.Now()
|
||||
|
||||
return SqlTimeStamp(tx)
|
||||
}
|
||||
|
||||
// SqlFloat64 - SQL Int
|
||||
type SqlFloat64 sql.NullFloat64
|
||||
|
||||
// Scan -
|
||||
func (n *SqlFloat64) Scan(value interface{}) error {
|
||||
newval := sql.NullFloat64{Float64: 0, Valid: false}
|
||||
if value == nil {
|
||||
newval.Valid = false
|
||||
*n = SqlFloat64(newval)
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case float64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case float32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case int64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case int32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint16:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint64:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
case uint32:
|
||||
newval.Float64 = float64(v)
|
||||
newval.Valid = true
|
||||
default:
|
||||
i, err := strconv.ParseInt(fmt.Sprintf("%v", v), 10, 64)
|
||||
newval.Float64 = float64(i)
|
||||
if err == nil {
|
||||
newval.Valid = false
|
||||
}
|
||||
}
|
||||
|
||||
*n = SqlFloat64(newval)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlFloat64) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return float64(n.Float64), nil
|
||||
}
|
||||
|
||||
// String -
|
||||
func (n SqlFloat64) String() string {
|
||||
if !n.Valid {
|
||||
return ""
|
||||
}
|
||||
tmstr := fmt.Sprintf("%f", n.Float64)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// UnmarshalJSON -
|
||||
func (n *SqlFloat64) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || (strings.Contains(s, "{") || strings.Contains(s, "["))
|
||||
if invalid {
|
||||
return nil
|
||||
}
|
||||
|
||||
nval, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*n = SqlFloat64(sql.NullFloat64{Valid: true, Float64: float64(nval)})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlFloat64) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("%f", n.Float64)), nil
|
||||
}
|
||||
|
||||
// SqlDate - Implementation of SqlTime with some interfaces.
|
||||
type SqlDate time.Time
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlDate) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
strings.HasPrefix(s, "0001-01-01T00:00:00") ||
|
||||
s == "0001-01-01" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*t = SqlDate(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlDate) MarshalJSON() ([]byte, error) {
|
||||
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlDate) Value() (driver.Value, error) {
|
||||
var s time.Time
|
||||
tmstr := time.Time(t).Format("2006-01-02")
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") || tmstr <= "0001-01-01" {
|
||||
return nil, nil
|
||||
}
|
||||
s = time.Time(t)
|
||||
|
||||
return s.Format("2006-01-02"), nil
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlDate) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlDate(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*t = SqlDate(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int64 - Override date format in unix epoch
|
||||
func (t SqlDate) Int64() int64 {
|
||||
return time.Time(t).Unix()
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlDate) String() string {
|
||||
tmstr := time.Time(t).Format("2006-01-02") // time.RFC3339
|
||||
if strings.HasPrefix(tmstr, "0001-01-01") || strings.HasPrefix(tmstr, "1800-12-31") {
|
||||
return "0"
|
||||
}
|
||||
return tmstr
|
||||
}
|
||||
|
||||
func SqlDateNow() SqlDate {
|
||||
tx := time.Now()
|
||||
return SqlDate(tx)
|
||||
}
|
||||
|
||||
// ////////////////////// SqlTime /////////////////////////
|
||||
// SqlTime - Implementation of SqlTime with some interfaces.
|
||||
type SqlTime time.Time
|
||||
|
||||
// Int64 - Override Time format in unix epoch
|
||||
func (t SqlTime) Int64() int64 {
|
||||
return time.Time(t).Unix()
|
||||
}
|
||||
|
||||
// String - Override String format of time
|
||||
func (t SqlTime) String() string {
|
||||
return time.Time(t).Format("15:04:05")
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON format of time
|
||||
func (t *SqlTime) UnmarshalJSON(b []byte) error {
|
||||
var err error
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
if s == "null" || s == "" || s == "0" ||
|
||||
s == "0001-01-01T00:00:00" || s == "00:00:00" {
|
||||
*t = SqlTime{}
|
||||
return nil
|
||||
}
|
||||
|
||||
tx, err := tryParseDT(s)
|
||||
*t = SqlTime(tx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Format - Format Function
|
||||
func (t SqlTime) Format(form string) string {
|
||||
tmstr := time.Time(t).Format(form)
|
||||
return tmstr
|
||||
}
|
||||
|
||||
// Scan - Scan custom date from sql
|
||||
func (t *SqlTime) Scan(value interface{}) error {
|
||||
tm, ok := value.(time.Time)
|
||||
if ok {
|
||||
*t = SqlTime(tm)
|
||||
return nil
|
||||
}
|
||||
|
||||
str, ok := value.(string)
|
||||
if ok {
|
||||
tx, err := tryParseDT(str)
|
||||
*t = SqlTime(tx)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value - SQL Value of custom date
|
||||
func (t SqlTime) Value() (driver.Value, error) {
|
||||
|
||||
s := time.Time(t)
|
||||
st := s.Format("15:04:05")
|
||||
|
||||
return st, nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (t SqlTime) MarshalJSON() ([]byte, error) {
|
||||
tmstr := time.Time(t).Format("15:04:05")
|
||||
if tmstr == "0001-01-01T00:00:00" || tmstr == "00:00:00" {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", tmstr)), nil
|
||||
}
|
||||
|
||||
func SqlTimeNow() SqlTime {
|
||||
tx := time.Now()
|
||||
return SqlTime(tx)
|
||||
}
|
||||
|
||||
// SqlJSONB - Nullable JSONB String
|
||||
type SqlJSONB []byte
|
||||
|
||||
// Scan - Implements sql.Scanner for reading JSONB from database
|
||||
func (n *SqlJSONB) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*n = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
*n = SqlJSONB([]byte(v))
|
||||
case []byte:
|
||||
*n = SqlJSONB(v)
|
||||
default:
|
||||
// For other types, marshal to JSON
|
||||
dat, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value to JSON: %v", err)
|
||||
}
|
||||
*n = SqlJSONB(dat)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value - Implements driver.Valuer for writing JSONB to database
|
||||
func (n SqlJSONB) Value() (driver.Value, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Validate that it's valid JSON before returning
|
||||
var js interface{}
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
|
||||
// Return as string for PostgreSQL JSONB/JSON columns
|
||||
return string(n), nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsMap() (map[string]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Validate that it's valid JSON before returning
|
||||
js := make(map[string]any)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func (n SqlJSONB) AsSlice() ([]any, error) {
|
||||
if len(n) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// Validate that it's valid JSON before returning
|
||||
js := make([]any, 0)
|
||||
if err := json.Unmarshal(n, &js); err != nil {
|
||||
return nil, fmt.Errorf("invalid JSON: %v", err)
|
||||
}
|
||||
return js, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON
|
||||
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 2) || (!strings.Contains(s, "{") && !strings.Contains(s, "["))
|
||||
if invalid {
|
||||
return nil
|
||||
}
|
||||
|
||||
*n = []byte(s)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
|
||||
if n == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
var obj interface{}
|
||||
err := json.Unmarshal(n, &obj)
|
||||
if err != nil {
|
||||
// fmt.Printf("Invalid JSON %v", err)
|
||||
return []byte("null"), nil
|
||||
}
|
||||
|
||||
// dat, err := json.MarshalIndent(obj, " ", " ")
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("failed to convert to JSON: %v", err)
|
||||
// }
|
||||
dat := n
|
||||
|
||||
return dat, nil
|
||||
}
|
||||
|
||||
// SqlUUID - Nullable UUID String
|
||||
type SqlUUID sql.NullString
|
||||
|
||||
// Scan -
|
||||
func (n *SqlUUID) Scan(value interface{}) error {
|
||||
str := sql.NullString{String: "", Valid: false}
|
||||
if value == nil {
|
||||
*n = SqlUUID(str)
|
||||
return nil
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
uuid, err := uuid.Parse(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
case []uint8:
|
||||
uuid, err := uuid.ParseBytes(v)
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
default:
|
||||
uuid, err := uuid.Parse(fmt.Sprintf("%v", v))
|
||||
if err == nil {
|
||||
str.String = uuid.String()
|
||||
str.Valid = true
|
||||
*n = SqlUUID(str)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value -
|
||||
func (n SqlUUID) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.String, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON - Override JSON
|
||||
func (n *SqlUUID) UnmarshalJSON(b []byte) error {
|
||||
|
||||
s := strings.Trim(strings.Trim(string(b), " "), "\"")
|
||||
invalid := (s == "null" || s == "" || len(s) < 30)
|
||||
if invalid {
|
||||
return nil
|
||||
}
|
||||
*n = SqlUUID(sql.NullString{String: s, Valid: !invalid})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON - Override JSON format of time
|
||||
func (n SqlUUID) MarshalJSON() ([]byte, error) {
|
||||
if !n.Valid {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return []byte(fmt.Sprintf("\"%s\"", n.String)), nil
|
||||
}
|
||||
|
||||
// TryIfInt64 - Wrapper function to quickly try and cast text to int
|
||||
func TryIfInt64(v any, def int64) int64 {
|
||||
str := ""
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
str = val
|
||||
case int:
|
||||
return int64(val)
|
||||
case int32:
|
||||
return int64(val)
|
||||
case int64:
|
||||
return val
|
||||
case uint32:
|
||||
return int64(val)
|
||||
case uint64:
|
||||
return int64(val)
|
||||
case float32:
|
||||
return int64(val)
|
||||
case float64:
|
||||
return int64(val)
|
||||
case []byte:
|
||||
str = string(val)
|
||||
default:
|
||||
str = fmt.Sprintf("%d", def)
|
||||
}
|
||||
val, err := strconv.ParseInt(str, 10, 64)
|
||||
if err != nil {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
566
pkg/common/sql_types_test.go
Normal file
566
pkg/common/sql_types_test.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestSqlInt16 tests SqlInt16 type
|
||||
func TestSqlInt16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt16
|
||||
}{
|
||||
{"int", 42, SqlInt16(42)},
|
||||
{"int32", int32(100), SqlInt16(100)},
|
||||
{"int64", int64(200), SqlInt16(200)},
|
||||
{"string", "123", SqlInt16(123)},
|
||||
{"nil", nil, SqlInt16(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt16
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlInt16
|
||||
expected driver.Value
|
||||
}{
|
||||
{"zero", SqlInt16(0), nil},
|
||||
{"positive", SqlInt16(42), int64(42)},
|
||||
{"negative", SqlInt16(-10), int64(-10)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlInt16_JSON(t *testing.T) {
|
||||
n := SqlInt16(42)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(n)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := "42"
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var n2 SqlInt16
|
||||
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if n2 != 123 {
|
||||
t.Errorf("expected 123, got %d", n2)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlInt64 tests SqlInt64 type
|
||||
func TestSqlInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected SqlInt64
|
||||
}{
|
||||
{"int", 42, SqlInt64(42)},
|
||||
{"int32", int32(100), SqlInt64(100)},
|
||||
{"int64", int64(9223372036854775807), SqlInt64(9223372036854775807)},
|
||||
{"uint32", uint32(100), SqlInt64(100)},
|
||||
{"uint64", uint64(200), SqlInt64(200)},
|
||||
{"nil", nil, SqlInt64(0)},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlInt64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlFloat64 tests SqlFloat64 type
|
||||
func TestSqlFloat64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected float64
|
||||
valid bool
|
||||
}{
|
||||
{"float64", float64(3.14), 3.14, true},
|
||||
{"float32", float32(2.5), 2.5, true},
|
||||
{"int", 42, 42.0, true},
|
||||
{"int64", int64(100), 100.0, true},
|
||||
{"nil", nil, 0, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var n SqlFloat64
|
||||
if err := n.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if n.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
|
||||
}
|
||||
if tt.valid && n.Float64 != tt.expected {
|
||||
t.Errorf("expected %v, got %v", tt.expected, n.Float64)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTimeStamp tests SqlTimeStamp type
|
||||
func TestSqlTimeStamp(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string RFC3339", now.Format(time.RFC3339)},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string datetime", "2024-01-15T10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var ts SqlTimeStamp
|
||||
if err := ts.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if ts.GetTime().IsZero() {
|
||||
t.Error("expected non-zero time")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlTimeStamp_JSON(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
|
||||
ts := SqlTimeStamp(now)
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(ts)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15T10:30:45"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var ts2 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if ts2.GetTime().Year() != 2024 {
|
||||
t.Errorf("expected year 2024, got %d", ts2.GetTime().Year())
|
||||
}
|
||||
|
||||
// Test null
|
||||
var ts3 SqlTimeStamp
|
||||
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlDate tests SqlDate type
|
||||
func TestSqlDate(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
}{
|
||||
{"time.Time", now},
|
||||
{"string date", "2024-01-15"},
|
||||
{"string UK format", "15/01/2024"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var d SqlDate
|
||||
if err := d.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if d.String() == "0" {
|
||||
t.Error("expected non-zero date")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlDate_JSON(t *testing.T) {
|
||||
date := SqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(date)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"2024-01-15"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var d2 SqlDate
|
||||
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlTime tests SqlTime type
|
||||
func TestSqlTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"time.Time", now, now.Format("15:04:05")},
|
||||
{"string time", "10:30:45", "10:30:45"},
|
||||
{"string short time", "10:30", "10:30:00"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tm SqlTime
|
||||
if err := tm.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tm.String() != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, tm.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlJSONB tests SqlJSONB type
|
||||
func TestSqlJSONB_Scan(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
}{
|
||||
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
|
||||
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
|
||||
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
|
||||
{"nil", nil, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var j SqlJSONB
|
||||
if err := j.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && j == nil {
|
||||
return // nil case
|
||||
}
|
||||
if string(j) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, string(j))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
expected string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
|
||||
{"empty", SqlJSONB{}, "", false},
|
||||
{"nil", nil, "", false},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
val, err := tt.input.Value()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if tt.expected == "" && val == nil {
|
||||
return // nil case
|
||||
}
|
||||
if val.(string) != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, val)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_JSON(t *testing.T) {
|
||||
// Marshal
|
||||
j := SqlJSONB(`{"name":"test","count":42}`)
|
||||
data, err := json.Marshal(j)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
t.Fatalf("Unmarshal result failed: %v", err)
|
||||
}
|
||||
if result["name"] != "test" {
|
||||
t.Errorf("expected name=test, got %v", result["name"])
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var j2 SqlJSONB
|
||||
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if string(j2) != `{"key":"value"}` {
|
||||
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
|
||||
}
|
||||
|
||||
// Test null
|
||||
var j3 SqlJSONB
|
||||
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsMap(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
|
||||
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
m, err := tt.input.AsMap()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsMap failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if m != nil {
|
||||
t.Errorf("expected nil, got %v", m)
|
||||
}
|
||||
return
|
||||
}
|
||||
if m == nil {
|
||||
t.Error("expected non-nil map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlJSONB_AsSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input SqlJSONB
|
||||
wantErr bool
|
||||
wantNil bool
|
||||
}{
|
||||
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
|
||||
{"empty", SqlJSONB{}, false, true},
|
||||
{"nil", nil, false, true},
|
||||
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
|
||||
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s, err := tt.input.AsSlice()
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("AsSlice failed: %v", err)
|
||||
}
|
||||
if tt.wantNil {
|
||||
if s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
return
|
||||
}
|
||||
if s == nil {
|
||||
t.Error("expected non-nil slice")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlUUID tests SqlUUID type
|
||||
func TestSqlUUID_Scan(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testUUIDStr := testUUID.String()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected string
|
||||
valid bool
|
||||
}{
|
||||
{"string UUID", testUUIDStr, testUUIDStr, true},
|
||||
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
|
||||
{"nil", nil, "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var u SqlUUID
|
||||
if err := u.Scan(tt.input); err != nil {
|
||||
t.Fatalf("Scan failed: %v", err)
|
||||
}
|
||||
if u.Valid != tt.valid {
|
||||
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
|
||||
}
|
||||
if tt.valid && u.String != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, u.String)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_Value(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
val, err := u.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), val)
|
||||
}
|
||||
|
||||
// Test invalid UUID
|
||||
u2 := SqlUUID{Valid: false}
|
||||
val2, err := u2.Value()
|
||||
if err != nil {
|
||||
t.Fatalf("Value failed: %v", err)
|
||||
}
|
||||
if val2 != nil {
|
||||
t.Errorf("expected nil, got %v", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlUUID_JSON(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
u := SqlUUID{String: testUUID.String(), Valid: true}
|
||||
|
||||
// Marshal
|
||||
data, err := json.Marshal(u)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
expected := `"` + testUUID.String() + `"`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
|
||||
// Unmarshal
|
||||
var u2 SqlUUID
|
||||
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
if u2.String != testUUID.String() {
|
||||
t.Errorf("expected %s, got %s", testUUID.String(), u2.String)
|
||||
}
|
||||
|
||||
// Test null
|
||||
var u3 SqlUUID
|
||||
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
|
||||
t.Fatalf("Unmarshal null failed: %v", err)
|
||||
}
|
||||
if u3.Valid {
|
||||
t.Error("expected invalid UUID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestTryIfInt64 tests the TryIfInt64 helper function
|
||||
func TestTryIfInt64(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
def int64
|
||||
expected int64
|
||||
}{
|
||||
{"string valid", "123", 0, 123},
|
||||
{"string invalid", "abc", 99, 99},
|
||||
{"int", 42, 0, 42},
|
||||
{"int32", int32(100), 0, 100},
|
||||
{"int64", int64(200), 0, 200},
|
||||
{"uint32", uint32(50), 0, 50},
|
||||
{"uint64", uint64(75), 0, 75},
|
||||
{"float32", float32(3.14), 0, 3},
|
||||
{"float64", float64(2.71), 0, 2},
|
||||
{"bytes", []byte("456"), 0, 456},
|
||||
{"unknown type", struct{}{}, 999, 999},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := TryIfInt64(tt.input, tt.def)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,11 @@ type RequestOptions struct {
|
||||
CustomOperators []CustomOperator `json:"customOperators"`
|
||||
ComputedColumns []ComputedColumn `json:"computedColumns"`
|
||||
Parameters []Parameter `json:"parameters"`
|
||||
|
||||
// Cursor pagination
|
||||
CursorForward string `json:"cursor_forward"`
|
||||
CursorBackward string `json:"cursor_backward"`
|
||||
FetchRowNumber *string `json:"fetch_row_number"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
@@ -27,19 +32,29 @@ type Parameter struct {
|
||||
}
|
||||
|
||||
type PreloadOption struct {
|
||||
Relation string `json:"relation"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Filters []FilterOption `json:"filters"`
|
||||
Limit *int `json:"limit"`
|
||||
Offset *int `json:"offset"`
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
Relation string `json:"relation"`
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
Filters []FilterOption `json:"filters"`
|
||||
Where string `json:"where"`
|
||||
Limit *int `json:"limit"`
|
||||
Offset *int `json:"offset"`
|
||||
Updatable *bool `json:"updateable"` // if true, the relation can be updated
|
||||
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
|
||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||
|
||||
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
Column string `json:"column"`
|
||||
Operator string `json:"operator"`
|
||||
Value interface{} `json:"value"`
|
||||
Column string `json:"column"`
|
||||
Operator string `json:"operator"`
|
||||
Value interface{} `json:"value"`
|
||||
LogicOperator string `json:"logic_operator"` // "AND" or "OR" - how this filter combines with previous filters
|
||||
}
|
||||
|
||||
type SortOption struct {
|
||||
@@ -66,10 +81,12 @@ type Response struct {
|
||||
}
|
||||
|
||||
type Metadata struct {
|
||||
Total int64 `json:"total"`
|
||||
Filtered int64 `json:"filtered"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Total int64 `json:"total"`
|
||||
Count int64 `json:"count"`
|
||||
Filtered int64 `json:"filtered"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
RowNumber *int64 `json:"row_number,omitempty"`
|
||||
}
|
||||
|
||||
type APIError struct {
|
||||
|
||||
287
pkg/common/validation.go
Normal file
287
pkg/common/validation.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// ColumnValidator validates column names against a model's fields
|
||||
type ColumnValidator struct {
|
||||
validColumns map[string]bool
|
||||
model interface{}
|
||||
}
|
||||
|
||||
// NewColumnValidator creates a new column validator for a given model
|
||||
func NewColumnValidator(model interface{}) *ColumnValidator {
|
||||
validator := &ColumnValidator{
|
||||
validColumns: make(map[string]bool),
|
||||
model: model,
|
||||
}
|
||||
validator.buildValidColumns()
|
||||
return validator
|
||||
}
|
||||
|
||||
// buildValidColumns extracts all valid column names from the model using reflection
|
||||
func (v *ColumnValidator) buildValidColumns() {
|
||||
modelType := reflect.TypeOf(v.model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return
|
||||
}
|
||||
|
||||
// Extract column names from struct fields
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get column name from bun, gorm, or json tag
|
||||
columnName := v.getColumnName(field)
|
||||
if columnName != "" && columnName != "-" {
|
||||
v.validColumns[strings.ToLower(columnName)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getColumnName extracts the column name from a struct field's tags
|
||||
// Supports both Bun and GORM tags
|
||||
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
|
||||
// First check Bun tag for column name
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
parts := strings.Split(bunTag, ",")
|
||||
// The first part is usually the column name
|
||||
columnName := strings.TrimSpace(parts[0])
|
||||
if columnName != "" && columnName != "-" {
|
||||
return columnName
|
||||
}
|
||||
}
|
||||
|
||||
// Check GORM tag for column name
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "column:") {
|
||||
parts := strings.Split(gormTag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "column:") {
|
||||
return strings.TrimPrefix(part, "column:")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract just the name part (before any comma)
|
||||
jsonName := strings.Split(jsonTag, ",")[0]
|
||||
return jsonName
|
||||
}
|
||||
|
||||
// Fall back to field name in lowercase (snake_case conversion would be better)
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// ValidateColumn validates a single column name
|
||||
// Returns nil if valid, error if invalid
|
||||
// Columns prefixed with "cql" (case insensitive) are always valid
|
||||
// Handles PostgreSQL JSON operators (-> and ->>)
|
||||
func (v *ColumnValidator) ValidateColumn(column string) error {
|
||||
// Allow empty columns
|
||||
if column == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Allow columns prefixed with "cql" (case insensitive) for computed columns
|
||||
if strings.HasPrefix(strings.ToLower(column), "cql") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract source column name (remove JSON operators like ->> or ->)
|
||||
sourceColumn := reflection.ExtractSourceColumn(column)
|
||||
|
||||
// Check if column exists in model
|
||||
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
|
||||
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValidColumn checks if a column is valid
|
||||
// Returns true if valid, false if invalid
|
||||
func (v *ColumnValidator) IsValidColumn(column string) bool {
|
||||
return v.ValidateColumn(column) == nil
|
||||
}
|
||||
|
||||
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||
// Logs warnings for any invalid columns
|
||||
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||
if len(columns) == 0 {
|
||||
return columns
|
||||
}
|
||||
|
||||
validColumns := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
if v.IsValidColumn(col) {
|
||||
validColumns = append(validColumns, col)
|
||||
} else {
|
||||
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
|
||||
}
|
||||
}
|
||||
return validColumns
|
||||
}
|
||||
|
||||
// ValidateColumns validates multiple column names
|
||||
// Returns error with details about all invalid columns
|
||||
func (v *ColumnValidator) ValidateColumns(columns []string) error {
|
||||
var invalidColumns []string
|
||||
|
||||
for _, column := range columns {
|
||||
if err := v.ValidateColumn(column); err != nil {
|
||||
invalidColumns = append(invalidColumns, column)
|
||||
}
|
||||
}
|
||||
|
||||
if len(invalidColumns) > 0 {
|
||||
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateRequestOptions validates all column references in RequestOptions
|
||||
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
|
||||
// Validate Columns
|
||||
if err := v.ValidateColumns(options.Columns); err != nil {
|
||||
return fmt.Errorf("in select columns: %w", err)
|
||||
}
|
||||
|
||||
// Validate OmitColumns
|
||||
if err := v.ValidateColumns(options.OmitColumns); err != nil {
|
||||
return fmt.Errorf("in omit columns: %w", err)
|
||||
}
|
||||
|
||||
// Validate Filter columns
|
||||
for _, filter := range options.Filters {
|
||||
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||
return fmt.Errorf("in filter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Sort columns
|
||||
for _, sort := range options.Sort {
|
||||
if err := v.ValidateColumn(sort.Column); err != nil {
|
||||
return fmt.Errorf("in sort: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Validate Preload columns (if specified)
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
// Note: We don't validate the relation name itself, as it's a relationship
|
||||
// Only validate columns if specified for the preload
|
||||
if err := v.ValidateColumns(preload.Columns); err != nil {
|
||||
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
|
||||
}
|
||||
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
|
||||
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
|
||||
}
|
||||
|
||||
// Validate filter columns in preload
|
||||
for _, filter := range preload.Filters {
|
||||
if err := v.ValidateColumn(filter.Column); err != nil {
|
||||
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FilterRequestOptions filters all column references in RequestOptions
|
||||
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
|
||||
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
|
||||
filtered := options
|
||||
|
||||
// Filter Columns
|
||||
filtered.Columns = v.FilterValidColumns(options.Columns)
|
||||
|
||||
// Filter OmitColumns
|
||||
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
|
||||
|
||||
// Filter Filter columns
|
||||
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||
for _, filter := range options.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
validFilters = append(validFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||
}
|
||||
}
|
||||
filtered.Filters = validFilters
|
||||
|
||||
// Filter Sort columns
|
||||
validSorts := make([]SortOption, 0, len(options.Sort))
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
}
|
||||
filtered.Sort = validSorts
|
||||
|
||||
// Filter Preload columns
|
||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
filteredPreload := preload
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Filter preload filters
|
||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||
for _, filter := range preload.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
}
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
|
||||
validPreloads = append(validPreloads, filteredPreload)
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
// GetValidColumns returns a list of all valid column names for debugging purposes
|
||||
func (v *ColumnValidator) GetValidColumns() []string {
|
||||
columns := make([]string, 0, len(v.validColumns))
|
||||
for col := range v.validColumns {
|
||||
columns = append(columns, col)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func QuoteIdent(qualifier string) string {
|
||||
return `"` + strings.ReplaceAll(qualifier, `"`, `""`) + `"`
|
||||
}
|
||||
|
||||
func QuoteLiteral(value string) string {
|
||||
return `'` + strings.ReplaceAll(value, `'`, `''`) + `'`
|
||||
}
|
||||
126
pkg/common/validation_json_test.go
Normal file
126
pkg/common/validation_json_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func TestExtractSourceColumn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple column name",
|
||||
input: "columna",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with ->> operator",
|
||||
input: "columna->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with -> operator",
|
||||
input: "columna->'key'",
|
||||
expected: "columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and ->> operator",
|
||||
input: "table.columna->>'val'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "column with table prefix and -> operator",
|
||||
input: "table.columna->'key'",
|
||||
expected: "table.columna",
|
||||
},
|
||||
{
|
||||
name: "complex JSON path with ->>",
|
||||
input: "data->>'nested'->>'value'",
|
||||
expected: "data",
|
||||
},
|
||||
{
|
||||
name: "column with spaces before operator",
|
||||
input: "columna ->>'val'",
|
||||
expected: "columna",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := reflection.ExtractSourceColumn(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumnWithJSONOperators(t *testing.T) {
|
||||
// Create a test model
|
||||
type TestModel struct {
|
||||
ID int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Data string `json:"data"` // JSON column
|
||||
Metadata string `json:"metadata"`
|
||||
}
|
||||
|
||||
validator := NewColumnValidator(TestModel{})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
column string
|
||||
shouldErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple valid column",
|
||||
column: "name",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with ->> operator",
|
||||
column: "data->>'field'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid column with -> operator",
|
||||
column: "metadata->'key'",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid column",
|
||||
column: "invalid_column",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid column with ->> operator",
|
||||
column: "invalid_column->>'field'",
|
||||
shouldErr: true,
|
||||
},
|
||||
{
|
||||
name: "cql prefixed column (always valid)",
|
||||
column: "cql_computed",
|
||||
shouldErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty column",
|
||||
column: "",
|
||||
shouldErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := validator.ValidateColumn(tc.column)
|
||||
if tc.shouldErr && err == nil {
|
||||
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
|
||||
}
|
||||
if !tc.shouldErr && err != nil {
|
||||
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
363
pkg/common/validation_test.go
Normal file
363
pkg/common/validation_test.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestModel represents a sample model for testing
|
||||
type TestModel struct {
|
||||
ID int64 `json:"id" gorm:"primaryKey"`
|
||||
Name string `json:"name" gorm:"column:name"`
|
||||
Email string `json:"email" bun:"email"`
|
||||
Age int `json:"age"`
|
||||
IsActive bool `json:"is_active"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
func TestNewColumnValidator(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
if validator == nil {
|
||||
t.Fatal("Expected validator to be created")
|
||||
}
|
||||
|
||||
if len(validator.validColumns) == 0 {
|
||||
t.Fatal("Expected validator to have valid columns")
|
||||
}
|
||||
|
||||
// Check that expected columns are present
|
||||
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
|
||||
for _, col := range expectedColumns {
|
||||
if !validator.validColumns[col] {
|
||||
t.Errorf("Expected column '%s' to be valid", col)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumn(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
column string
|
||||
shouldError bool
|
||||
}{
|
||||
{"Valid column - id", "id", false},
|
||||
{"Valid column - name", "name", false},
|
||||
{"Valid column - email", "email", false},
|
||||
{"Valid column - uppercase", "ID", false}, // Case insensitive
|
||||
{"Invalid column", "invalid_column", true},
|
||||
{"CQL prefixed - should be valid", "cqlComputedField", false},
|
||||
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
|
||||
{"Empty column", "", false}, // Empty columns are allowed
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateColumn(tt.column)
|
||||
if tt.shouldError && err == nil {
|
||||
t.Errorf("Expected error for column '%s', got nil", tt.column)
|
||||
}
|
||||
if !tt.shouldError && err != nil {
|
||||
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateColumns(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columns []string
|
||||
shouldError bool
|
||||
}{
|
||||
{"All valid columns", []string{"id", "name", "email"}, false},
|
||||
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
|
||||
{"All invalid columns", []string{"bad1", "bad2"}, true},
|
||||
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
|
||||
{"Empty list", []string{}, false},
|
||||
{"Nil list", nil, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateColumns(tt.columns)
|
||||
if tt.shouldError && err == nil {
|
||||
t.Errorf("Expected error for columns %v, got nil", tt.columns)
|
||||
}
|
||||
if !tt.shouldError && err != nil {
|
||||
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRequestOptions(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
options RequestOptions
|
||||
shouldError bool
|
||||
errorMsg string
|
||||
}{
|
||||
{
|
||||
name: "Valid options with columns",
|
||||
options: RequestOptions{
|
||||
Columns: []string{"id", "name"},
|
||||
Filters: []FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
},
|
||||
Sort: []SortOption{
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid column in Columns",
|
||||
options: RequestOptions{
|
||||
Columns: []string{"id", "invalid_column"},
|
||||
},
|
||||
shouldError: true,
|
||||
errorMsg: "select columns",
|
||||
},
|
||||
{
|
||||
name: "Invalid column in Filters",
|
||||
options: RequestOptions{
|
||||
Filters: []FilterOption{
|
||||
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||
},
|
||||
},
|
||||
shouldError: true,
|
||||
errorMsg: "filter",
|
||||
},
|
||||
{
|
||||
name: "Invalid column in Sort",
|
||||
options: RequestOptions{
|
||||
Sort: []SortOption{
|
||||
{Column: "invalid_col", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
shouldError: true,
|
||||
errorMsg: "sort",
|
||||
},
|
||||
{
|
||||
name: "Valid CQL prefixed columns",
|
||||
options: RequestOptions{
|
||||
Columns: []string{"id", "cqlComputedField"},
|
||||
Filters: []FilterOption{
|
||||
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
|
||||
},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid column in Preload",
|
||||
options: RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "SomeRelation",
|
||||
Columns: []string{"id", "invalid_col"},
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldError: true,
|
||||
errorMsg: "preload",
|
||||
},
|
||||
{
|
||||
name: "Valid preload with valid columns",
|
||||
options: RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "SomeRelation",
|
||||
Columns: []string{"id", "name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validator.ValidateRequestOptions(tt.options)
|
||||
if tt.shouldError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error, got nil")
|
||||
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
|
||||
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidColumns(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
columns := validator.GetValidColumns()
|
||||
if len(columns) == 0 {
|
||||
t.Error("Expected to get valid columns, got empty list")
|
||||
}
|
||||
|
||||
// Should have at least the columns from TestModel
|
||||
if len(columns) < 6 {
|
||||
t.Errorf("Expected at least 6 columns, got %d", len(columns))
|
||||
}
|
||||
}
|
||||
|
||||
// Test with Bun tags specifically
|
||||
type BunModel struct {
|
||||
ID int64 `bun:"id,pk"`
|
||||
Name string `bun:"name"`
|
||||
Email string `bun:"user_email"`
|
||||
}
|
||||
|
||||
func TestBunTagSupport(t *testing.T) {
|
||||
model := BunModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
// Test that bun tags are properly recognized
|
||||
tests := []struct {
|
||||
column string
|
||||
shouldError bool
|
||||
}{
|
||||
{"id", false},
|
||||
{"name", false},
|
||||
{"user_email", false}, // Bun tag specifies this name
|
||||
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.column, func(t *testing.T) {
|
||||
err := validator.ValidateColumn(tt.column)
|
||||
if tt.shouldError && err == nil {
|
||||
t.Errorf("Expected error for column '%s'", tt.column)
|
||||
}
|
||||
if !tt.shouldError && err != nil {
|
||||
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterValidColumns(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []string
|
||||
expectedOutput []string
|
||||
}{
|
||||
{
|
||||
name: "All valid columns",
|
||||
input: []string{"id", "name", "email"},
|
||||
expectedOutput: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "Mix of valid and invalid",
|
||||
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
|
||||
expectedOutput: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "All invalid columns",
|
||||
input: []string{"bad1", "bad2"},
|
||||
expectedOutput: []string{},
|
||||
},
|
||||
{
|
||||
name: "With CQL prefix (should pass)",
|
||||
input: []string{"id", "cqlComputed", "name"},
|
||||
expectedOutput: []string{"id", "cqlComputed", "name"},
|
||||
},
|
||||
{
|
||||
name: "Empty input",
|
||||
input: []string{},
|
||||
expectedOutput: []string{},
|
||||
},
|
||||
{
|
||||
name: "Nil input",
|
||||
input: nil,
|
||||
expectedOutput: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := validator.FilterValidColumns(tt.input)
|
||||
if len(result) != len(tt.expectedOutput) {
|
||||
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expectedOutput[i] {
|
||||
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Columns: []string{"id", "name", "invalid_col"},
|
||||
OmitColumns: []string{"email", "bad_col"},
|
||||
Filters: []FilterOption{
|
||||
{Column: "name", Operator: "eq", Value: "test"},
|
||||
{Column: "invalid_col", Operator: "eq", Value: "test"},
|
||||
},
|
||||
Sort: []SortOption{
|
||||
{Column: "id", Direction: "ASC"},
|
||||
{Column: "bad_col", Direction: "DESC"},
|
||||
},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Check Columns
|
||||
if len(filtered.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||
}
|
||||
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
|
||||
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
|
||||
}
|
||||
|
||||
// Check OmitColumns
|
||||
if len(filtered.OmitColumns) != 1 {
|
||||
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
|
||||
}
|
||||
if filtered.OmitColumns[0] != "email" {
|
||||
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
|
||||
}
|
||||
|
||||
// Check Filters
|
||||
if len(filtered.Filters) != 1 {
|
||||
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
|
||||
}
|
||||
if filtered.Filters[0].Column != "name" {
|
||||
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
|
||||
}
|
||||
|
||||
// Check Sort
|
||||
if len(filtered.Sort) != 1 {
|
||||
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
|
||||
}
|
||||
if filtered.Sort[0].Column != "id" {
|
||||
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
|
||||
}
|
||||
}
|
||||
291
pkg/config/README.md
Normal file
291
pkg/config/README.md
Normal file
@@ -0,0 +1,291 @@
|
||||
# ResolveSpec Configuration System
|
||||
|
||||
A centralized configuration system with support for multiple configuration sources: config files (YAML, TOML, JSON), environment variables, and programmatic configuration.
|
||||
|
||||
## Features
|
||||
|
||||
- **Multiple Config Sources**: Config files, environment variables, and code
|
||||
- **Priority Order**: Environment variables > Config file > Defaults
|
||||
- **Multiple Formats**: YAML, TOML, JSON supported
|
||||
- **Type Safety**: Strongly-typed configuration structs
|
||||
- **Sensible Defaults**: Works out of the box with reasonable defaults
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
import "github.com/heinhel/ResolveSpec/pkg/config"
|
||||
|
||||
// Create a new config manager
|
||||
mgr := config.NewManager()
|
||||
|
||||
// Load configuration from file and environment
|
||||
if err := mgr.Load(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Get the complete configuration
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Use the configuration
|
||||
fmt.Println("Server address:", cfg.Server.Addr)
|
||||
```
|
||||
|
||||
### Custom Configuration Paths
|
||||
|
||||
```go
|
||||
mgr := config.NewManagerWithOptions(
|
||||
config.WithConfigFile("/path/to/config.yaml"),
|
||||
config.WithEnvPrefix("MYAPP"),
|
||||
)
|
||||
```
|
||||
|
||||
## Configuration Sources
|
||||
|
||||
### 1. Config Files
|
||||
|
||||
Place a `config.yaml` file in one of these locations:
|
||||
- Current directory (`.`)
|
||||
- `./config/`
|
||||
- `/etc/resolvespec/`
|
||||
- `$HOME/.resolvespec/`
|
||||
|
||||
Example `config.yaml`:
|
||||
|
||||
```yaml
|
||||
server:
|
||||
addr: ":8080"
|
||||
shutdown_timeout: 30s
|
||||
|
||||
tracing:
|
||||
enabled: true
|
||||
service_name: "my-service"
|
||||
|
||||
cache:
|
||||
provider: "redis"
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
```
|
||||
|
||||
### 2. Environment Variables
|
||||
|
||||
All configuration can be set via environment variables with the `RESOLVESPEC_` prefix:
|
||||
|
||||
```bash
|
||||
export RESOLVESPEC_SERVER_ADDR=":9090"
|
||||
export RESOLVESPEC_TRACING_ENABLED=true
|
||||
export RESOLVESPEC_CACHE_PROVIDER=redis
|
||||
export RESOLVESPEC_CACHE_REDIS_HOST=localhost
|
||||
```
|
||||
|
||||
Nested configuration uses underscores:
|
||||
- `server.addr` → `RESOLVESPEC_SERVER_ADDR`
|
||||
- `cache.redis.host` → `RESOLVESPEC_CACHE_REDIS_HOST`
|
||||
|
||||
### 3. Programmatic Configuration
|
||||
|
||||
```go
|
||||
mgr := config.NewManager()
|
||||
mgr.Set("server.addr", ":9090")
|
||||
mgr.Set("tracing.enabled", true)
|
||||
|
||||
cfg, _ := mgr.GetConfig()
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Server Configuration
|
||||
|
||||
```yaml
|
||||
server:
|
||||
addr: ":8080" # Server address
|
||||
shutdown_timeout: 30s # Graceful shutdown timeout
|
||||
drain_timeout: 25s # Connection drain timeout
|
||||
read_timeout: 10s # HTTP read timeout
|
||||
write_timeout: 10s # HTTP write timeout
|
||||
idle_timeout: 120s # HTTP idle timeout
|
||||
```
|
||||
|
||||
### Tracing Configuration
|
||||
|
||||
```yaml
|
||||
tracing:
|
||||
enabled: false # Enable/disable tracing
|
||||
service_name: "resolvespec" # Service name
|
||||
service_version: "1.0.0" # Service version
|
||||
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
|
||||
```
|
||||
|
||||
### Cache Configuration
|
||||
|
||||
```yaml
|
||||
cache:
|
||||
provider: "memory" # Options: memory, redis, memcache
|
||||
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
memcache:
|
||||
servers:
|
||||
- "localhost:11211"
|
||||
max_idle_conns: 10
|
||||
timeout: 100ms
|
||||
```
|
||||
|
||||
### Logger Configuration
|
||||
|
||||
```yaml
|
||||
logger:
|
||||
dev: false # Development mode (human-readable output)
|
||||
path: "" # Log file path (empty = stdout)
|
||||
```
|
||||
|
||||
### Middleware Configuration
|
||||
|
||||
```yaml
|
||||
middleware:
|
||||
rate_limit_rps: 100.0 # Requests per second
|
||||
rate_limit_burst: 200 # Burst size
|
||||
max_request_size: 10485760 # Max request size in bytes (10MB)
|
||||
```
|
||||
|
||||
### CORS Configuration
|
||||
|
||||
```yaml
|
||||
cors:
|
||||
allowed_origins:
|
||||
- "*"
|
||||
allowed_methods:
|
||||
- "GET"
|
||||
- "POST"
|
||||
- "PUT"
|
||||
- "DELETE"
|
||||
- "OPTIONS"
|
||||
allowed_headers:
|
||||
- "*"
|
||||
max_age: 3600
|
||||
```
|
||||
|
||||
### Database Configuration
|
||||
|
||||
```yaml
|
||||
database:
|
||||
url: "host=localhost user=postgres password=postgres dbname=mydb port=5432 sslmode=disable"
|
||||
```
|
||||
|
||||
## Priority and Overrides
|
||||
|
||||
Configuration sources are applied in this order (highest priority first):
|
||||
|
||||
1. **Environment Variables** (highest priority)
|
||||
2. **Config File**
|
||||
3. **Defaults** (lowest priority)
|
||||
|
||||
This allows you to:
|
||||
- Set defaults in code
|
||||
- Override with a config file
|
||||
- Override specific values with environment variables
|
||||
|
||||
## Examples
|
||||
|
||||
### Production Setup
|
||||
|
||||
```yaml
|
||||
# config.yaml
|
||||
server:
|
||||
addr: ":8080"
|
||||
|
||||
tracing:
|
||||
enabled: true
|
||||
service_name: "myapi"
|
||||
endpoint: "http://jaeger:4318/v1/traces"
|
||||
|
||||
cache:
|
||||
provider: "redis"
|
||||
redis:
|
||||
host: "redis"
|
||||
port: 6379
|
||||
password: "${REDIS_PASSWORD}"
|
||||
|
||||
logger:
|
||||
dev: false
|
||||
path: "/var/log/myapi/app.log"
|
||||
```
|
||||
|
||||
### Development Setup
|
||||
|
||||
```bash
|
||||
# Use environment variables for development
|
||||
export RESOLVESPEC_LOGGER_DEV=true
|
||||
export RESOLVESPEC_TRACING_ENABLED=false
|
||||
export RESOLVESPEC_CACHE_PROVIDER=memory
|
||||
```
|
||||
|
||||
### Testing Setup
|
||||
|
||||
```go
|
||||
// Override config for tests
|
||||
mgr := config.NewManager()
|
||||
mgr.Set("cache.provider", "memory")
|
||||
mgr.Set("database.url", testDBURL)
|
||||
|
||||
cfg, _ := mgr.GetConfig()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use config files for base configuration** - Define your standard settings
|
||||
2. **Use environment variables for secrets** - Never commit passwords/tokens
|
||||
3. **Use environment variables for deployment-specific values** - Different per environment
|
||||
4. **Keep defaults sensible** - Application should work with minimal configuration
|
||||
5. **Document your configuration** - Comment your config.yaml files
|
||||
|
||||
## Integration with ResolveSpec Components
|
||||
|
||||
The configuration system integrates seamlessly with ResolveSpec components:
|
||||
|
||||
```go
|
||||
cfg, _ := config.NewManager().Load().GetConfig()
|
||||
|
||||
// Server
|
||||
srv := server.NewGracefulServer(server.Config{
|
||||
Addr: cfg.Server.Addr,
|
||||
ShutdownTimeout: cfg.Server.ShutdownTimeout,
|
||||
// ... other fields
|
||||
})
|
||||
|
||||
// Tracing
|
||||
if cfg.Tracing.Enabled {
|
||||
tracer := tracing.Init(tracing.Config{
|
||||
ServiceName: cfg.Tracing.ServiceName,
|
||||
ServiceVersion: cfg.Tracing.ServiceVersion,
|
||||
Endpoint: cfg.Tracing.Endpoint,
|
||||
})
|
||||
defer tracer.Shutdown(context.Background())
|
||||
}
|
||||
|
||||
// Cache
|
||||
var cacheProvider cache.Provider
|
||||
switch cfg.Cache.Provider {
|
||||
case "redis":
|
||||
cacheProvider = cache.NewRedisProvider(cfg.Cache.Redis.Host, cfg.Cache.Redis.Port, ...)
|
||||
case "memcache":
|
||||
cacheProvider = cache.NewMemcacheProvider(cfg.Cache.Memcache.Servers, ...)
|
||||
default:
|
||||
cacheProvider = cache.NewMemoryProvider()
|
||||
}
|
||||
|
||||
// Logger
|
||||
logger.Init(cfg.Logger.Dev)
|
||||
if cfg.Logger.Path != "" {
|
||||
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
|
||||
}
|
||||
```
|
||||
80
pkg/config/config.go
Normal file
80
pkg/config/config.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
// Config represents the complete application configuration
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Tracing TracingConfig `mapstructure:"tracing"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
Middleware MiddlewareConfig `mapstructure:"middleware"`
|
||||
CORS CORSConfig `mapstructure:"cors"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
}
|
||||
|
||||
// ServerConfig holds server-related configuration
|
||||
type ServerConfig struct {
|
||||
Addr string `mapstructure:"addr"`
|
||||
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
|
||||
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
||||
}
|
||||
|
||||
// TracingConfig holds OpenTelemetry tracing configuration
|
||||
type TracingConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
ServiceName string `mapstructure:"service_name"`
|
||||
ServiceVersion string `mapstructure:"service_version"`
|
||||
Endpoint string `mapstructure:"endpoint"`
|
||||
}
|
||||
|
||||
// CacheConfig holds cache provider configuration
|
||||
type CacheConfig struct {
|
||||
Provider string `mapstructure:"provider"` // memory, redis, memcache
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Memcache MemcacheConfig `mapstructure:"memcache"`
|
||||
}
|
||||
|
||||
// RedisConfig holds Redis-specific configuration
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
}
|
||||
|
||||
// MemcacheConfig holds Memcache-specific configuration
|
||||
type MemcacheConfig struct {
|
||||
Servers []string `mapstructure:"servers"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
// LoggerConfig holds logger configuration
|
||||
type LoggerConfig struct {
|
||||
Dev bool `mapstructure:"dev"`
|
||||
Path string `mapstructure:"path"`
|
||||
}
|
||||
|
||||
// MiddlewareConfig holds middleware configuration
|
||||
type MiddlewareConfig struct {
|
||||
RateLimitRPS float64 `mapstructure:"rate_limit_rps"`
|
||||
RateLimitBurst int `mapstructure:"rate_limit_burst"`
|
||||
MaxRequestSize int64 `mapstructure:"max_request_size"`
|
||||
}
|
||||
|
||||
// CORSConfig holds CORS configuration
|
||||
type CORSConfig struct {
|
||||
AllowedOrigins []string `mapstructure:"allowed_origins"`
|
||||
AllowedMethods []string `mapstructure:"allowed_methods"`
|
||||
AllowedHeaders []string `mapstructure:"allowed_headers"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds database configuration (primarily for testing)
|
||||
type DatabaseConfig struct {
|
||||
URL string `mapstructure:"url"`
|
||||
}
|
||||
168
pkg/config/manager.go
Normal file
168
pkg/config/manager.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// Manager handles configuration loading from multiple sources
|
||||
type Manager struct {
|
||||
v *viper.Viper
|
||||
}
|
||||
|
||||
// NewManager creates a new configuration manager with defaults
|
||||
func NewManager() *Manager {
|
||||
v := viper.New()
|
||||
|
||||
// Set configuration file settings
|
||||
v.SetConfigName("config")
|
||||
v.SetConfigType("yaml")
|
||||
v.AddConfigPath(".")
|
||||
v.AddConfigPath("./config")
|
||||
v.AddConfigPath("/etc/resolvespec")
|
||||
v.AddConfigPath("$HOME/.resolvespec")
|
||||
|
||||
// Enable environment variable support
|
||||
v.SetEnvPrefix("RESOLVESPEC")
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
v.AutomaticEnv()
|
||||
|
||||
// Set default values
|
||||
setDefaults(v)
|
||||
|
||||
return &Manager{v: v}
|
||||
}
|
||||
|
||||
// NewManagerWithOptions creates a new configuration manager with custom options
|
||||
func NewManagerWithOptions(opts ...Option) *Manager {
|
||||
m := NewManager()
|
||||
for _, opt := range opts {
|
||||
opt(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring the Manager
|
||||
type Option func(*Manager)
|
||||
|
||||
// WithConfigFile sets a specific config file path
|
||||
func WithConfigFile(path string) Option {
|
||||
return func(m *Manager) {
|
||||
m.v.SetConfigFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfigName sets the config file name (without extension)
|
||||
func WithConfigName(name string) Option {
|
||||
return func(m *Manager) {
|
||||
m.v.SetConfigName(name)
|
||||
}
|
||||
}
|
||||
|
||||
// WithConfigPath adds a path to search for config files
|
||||
func WithConfigPath(path string) Option {
|
||||
return func(m *Manager) {
|
||||
m.v.AddConfigPath(path)
|
||||
}
|
||||
}
|
||||
|
||||
// WithEnvPrefix sets the environment variable prefix
|
||||
func WithEnvPrefix(prefix string) Option {
|
||||
return func(m *Manager) {
|
||||
m.v.SetEnvPrefix(prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Load attempts to load configuration from file and environment
|
||||
func (m *Manager) Load() error {
|
||||
// Try to read config file (not an error if it doesn't exist)
|
||||
if err := m.v.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return fmt.Errorf("error reading config file: %w", err)
|
||||
}
|
||||
// Config file not found; will rely on defaults and env vars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig returns the complete configuration
|
||||
func (m *Manager) GetConfig() (*Config, error) {
|
||||
var cfg Config
|
||||
if err := m.v.Unmarshal(&cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// Get returns a configuration value by key
|
||||
func (m *Manager) Get(key string) interface{} {
|
||||
return m.v.Get(key)
|
||||
}
|
||||
|
||||
// GetString returns a string configuration value
|
||||
func (m *Manager) GetString(key string) string {
|
||||
return m.v.GetString(key)
|
||||
}
|
||||
|
||||
// GetInt returns an int configuration value
|
||||
func (m *Manager) GetInt(key string) int {
|
||||
return m.v.GetInt(key)
|
||||
}
|
||||
|
||||
// GetBool returns a bool configuration value
|
||||
func (m *Manager) GetBool(key string) bool {
|
||||
return m.v.GetBool(key)
|
||||
}
|
||||
|
||||
// Set sets a configuration value
|
||||
func (m *Manager) Set(key string, value interface{}) {
|
||||
m.v.Set(key, value)
|
||||
}
|
||||
|
||||
// setDefaults sets default configuration values
|
||||
func setDefaults(v *viper.Viper) {
|
||||
// Server defaults
|
||||
v.SetDefault("server.addr", ":8080")
|
||||
v.SetDefault("server.shutdown_timeout", "30s")
|
||||
v.SetDefault("server.drain_timeout", "25s")
|
||||
v.SetDefault("server.read_timeout", "10s")
|
||||
v.SetDefault("server.write_timeout", "10s")
|
||||
v.SetDefault("server.idle_timeout", "120s")
|
||||
|
||||
// Tracing defaults
|
||||
v.SetDefault("tracing.enabled", false)
|
||||
v.SetDefault("tracing.service_name", "resolvespec")
|
||||
v.SetDefault("tracing.service_version", "1.0.0")
|
||||
v.SetDefault("tracing.endpoint", "")
|
||||
|
||||
// Cache defaults
|
||||
v.SetDefault("cache.provider", "memory")
|
||||
v.SetDefault("cache.redis.host", "localhost")
|
||||
v.SetDefault("cache.redis.port", 6379)
|
||||
v.SetDefault("cache.redis.password", "")
|
||||
v.SetDefault("cache.redis.db", 0)
|
||||
v.SetDefault("cache.memcache.servers", []string{"localhost:11211"})
|
||||
v.SetDefault("cache.memcache.max_idle_conns", 10)
|
||||
v.SetDefault("cache.memcache.timeout", "100ms")
|
||||
|
||||
// Logger defaults
|
||||
v.SetDefault("logger.dev", false)
|
||||
v.SetDefault("logger.path", "")
|
||||
|
||||
// Middleware defaults
|
||||
v.SetDefault("middleware.rate_limit_rps", 100.0)
|
||||
v.SetDefault("middleware.rate_limit_burst", 200)
|
||||
v.SetDefault("middleware.max_request_size", 10485760) // 10MB
|
||||
|
||||
// CORS defaults
|
||||
v.SetDefault("cors.allowed_origins", []string{"*"})
|
||||
v.SetDefault("cors.allowed_methods", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"})
|
||||
v.SetDefault("cors.allowed_headers", []string{"*"})
|
||||
v.SetDefault("cors.max_age", 3600)
|
||||
|
||||
// Database defaults
|
||||
v.SetDefault("database.url", "")
|
||||
}
|
||||
166
pkg/config/manager_test.go
Normal file
166
pkg/config/manager_test.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if mgr == nil {
|
||||
t.Fatal("Expected manager to be non-nil")
|
||||
}
|
||||
|
||||
if mgr.v == nil {
|
||||
t.Fatal("Expected viper instance to be non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultValues(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Test default values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"server.addr", cfg.Server.Addr, ":8080"},
|
||||
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
|
||||
{"tracing.enabled", cfg.Tracing.Enabled, false},
|
||||
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
|
||||
{"cache.provider", cfg.Cache.Provider, "memory"},
|
||||
{"cache.redis.host", cfg.Cache.Redis.Host, "localhost"},
|
||||
{"cache.redis.port", cfg.Cache.Redis.Port, 6379},
|
||||
{"logger.dev", cfg.Logger.Dev, false},
|
||||
{"middleware.rate_limit_rps", cfg.Middleware.RateLimitRPS, 100.0},
|
||||
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvironmentVariableOverrides(t *testing.T) {
|
||||
// Set environment variables
|
||||
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
|
||||
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
|
||||
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
|
||||
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
|
||||
defer func() {
|
||||
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
|
||||
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
|
||||
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
|
||||
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
|
||||
}()
|
||||
|
||||
mgr := NewManager()
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
// Test environment variable overrides
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"server.addr", cfg.Server.Addr, ":9090"},
|
||||
{"tracing.enabled", cfg.Tracing.Enabled, true},
|
||||
{"cache.provider", cfg.Cache.Provider, "redis"},
|
||||
{"logger.dev", cfg.Logger.Dev, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProgrammaticConfiguration(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.Set("server.addr", ":7070")
|
||||
mgr.Set("tracing.service_name", "test-service")
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Addr != ":7070" {
|
||||
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
|
||||
}
|
||||
|
||||
if cfg.Tracing.ServiceName != "test-service" {
|
||||
t.Errorf("tracing.service_name: got %s, want test-service", cfg.Tracing.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetterMethods(t *testing.T) {
|
||||
mgr := NewManager()
|
||||
mgr.Set("test.string", "value")
|
||||
mgr.Set("test.int", 42)
|
||||
mgr.Set("test.bool", true)
|
||||
|
||||
if got := mgr.GetString("test.string"); got != "value" {
|
||||
t.Errorf("GetString: got %s, want value", got)
|
||||
}
|
||||
|
||||
if got := mgr.GetInt("test.int"); got != 42 {
|
||||
t.Errorf("GetInt: got %d, want 42", got)
|
||||
}
|
||||
|
||||
if got := mgr.GetBool("test.bool"); !got {
|
||||
t.Errorf("GetBool: got %v, want true", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithOptions(t *testing.T) {
|
||||
mgr := NewManagerWithOptions(
|
||||
WithEnvPrefix("MYAPP"),
|
||||
WithConfigName("myconfig"),
|
||||
)
|
||||
|
||||
if mgr == nil {
|
||||
t.Fatal("Expected manager to be non-nil")
|
||||
}
|
||||
|
||||
// Set environment variable with custom prefix
|
||||
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
|
||||
defer os.Unsetenv("MYAPP_SERVER_ADDR")
|
||||
|
||||
if err := mgr.Load(); err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := mgr.GetConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get config: %v", err)
|
||||
}
|
||||
|
||||
if cfg.Server.Addr != ":5000" {
|
||||
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
|
||||
}
|
||||
}
|
||||
1015
pkg/funcspec/function_api.go
Normal file
1015
pkg/funcspec/function_api.go
Normal file
File diff suppressed because it is too large
Load Diff
899
pkg/funcspec/function_api_test.go
Normal file
899
pkg/funcspec/function_api_test.go
Normal file
@@ -0,0 +1,899 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// MockDatabase implements common.Database interface for testing
|
||||
type MockDatabase struct {
|
||||
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
|
||||
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
|
||||
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
|
||||
}
|
||||
|
||||
func (m *MockDatabase) NewSelect() common.SelectQuery {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) NewInsert() common.InsertQuery {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) NewDelete() common.DeleteQuery {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||
if m.ExecFunc != nil {
|
||||
return m.ExecFunc(ctx, query, args...)
|
||||
}
|
||||
return &MockResult{rows: 0}, nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
if m.QueryFunc != nil {
|
||||
return m.QueryFunc(ctx, dest, query, args...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) CommitTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
|
||||
if m.RunInTransactionFunc != nil {
|
||||
return m.RunInTransactionFunc(ctx, fn)
|
||||
}
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
// MockResult implements common.Result interface for testing
|
||||
type MockResult struct {
|
||||
rows int64
|
||||
id int64
|
||||
}
|
||||
|
||||
func (m *MockResult) RowsAffected() int64 {
|
||||
return m.rows
|
||||
}
|
||||
|
||||
func (m *MockResult) LastInsertId() (int64, error) {
|
||||
return m.id, nil
|
||||
}
|
||||
|
||||
// Helper function to create a test request with user context
|
||||
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
|
||||
u, _ := url.Parse(path)
|
||||
if queryParams != nil {
|
||||
q := u.Query()
|
||||
for k, v := range queryParams {
|
||||
q.Set(k, v)
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
var bodyReader *bytes.Reader
|
||||
if body != nil {
|
||||
bodyReader = bytes.NewReader(body)
|
||||
} else {
|
||||
bodyReader = bytes.NewReader([]byte{})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(method, u.String(), bodyReader)
|
||||
|
||||
if headers != nil {
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
// Add user context
|
||||
userCtx := &security.UserContext{
|
||||
UserID: 1,
|
||||
UserName: "testuser",
|
||||
SessionID: "test-session-123",
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
return req
|
||||
}
|
||||
|
||||
// TestNewHandler tests handler creation
|
||||
func TestNewHandler(t *testing.T) {
|
||||
db := &MockDatabase{}
|
||||
handler := NewHandler(db)
|
||||
|
||||
if handler == nil {
|
||||
t.Fatal("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.db != db {
|
||||
t.Error("Expected handler to have the provided database")
|
||||
}
|
||||
|
||||
if handler.hooks == nil {
|
||||
t.Error("Expected handler to have a hook registry")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandlerHooks tests the Hooks method
|
||||
func TestHandlerHooks(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
hooks := handler.Hooks()
|
||||
|
||||
if hooks == nil {
|
||||
t.Fatal("Expected hooks registry to be non-nil")
|
||||
}
|
||||
|
||||
// Should return the same instance
|
||||
hooks2 := handler.Hooks()
|
||||
if hooks != hooks2 {
|
||||
t.Error("Expected Hooks() to return the same registry instance")
|
||||
}
|
||||
}
|
||||
|
||||
// TestExtractInputVariables tests the extractInputVariables function
|
||||
func TestExtractInputVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedVars []string
|
||||
}{
|
||||
{
|
||||
name: "No variables",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
expectedVars: []string{},
|
||||
},
|
||||
{
|
||||
name: "Single variable",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||
expectedVars: []string{"[user_id]"},
|
||||
},
|
||||
{
|
||||
name: "Multiple variables",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
|
||||
expectedVars: []string{"[user_id]", "[user_name]"},
|
||||
},
|
||||
{
|
||||
name: "Nested brackets",
|
||||
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
|
||||
expectedVars: []string{"[field]", "[user_id]"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
inputvars := make([]string, 0)
|
||||
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
|
||||
|
||||
if result != tt.sqlQuery {
|
||||
t.Errorf("Expected SQL query to be unchanged, got %s", result)
|
||||
}
|
||||
|
||||
if len(inputvars) != len(tt.expectedVars) {
|
||||
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expectedVars {
|
||||
if inputvars[i] != expected {
|
||||
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidSQL tests the SQL sanitization function
|
||||
func TestValidSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
mode string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Column name with valid characters",
|
||||
input: "user_id",
|
||||
mode: "colname",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "Column name with dots (table.column)",
|
||||
input: "users.user_id",
|
||||
mode: "colname",
|
||||
expected: "users.user_id",
|
||||
},
|
||||
{
|
||||
name: "Column name with SQL injection attempt",
|
||||
input: "id'; DROP TABLE users--",
|
||||
mode: "colname",
|
||||
expected: "idDROPTABLEusers",
|
||||
},
|
||||
{
|
||||
name: "Column value with single quotes",
|
||||
input: "O'Brien",
|
||||
mode: "colvalue",
|
||||
expected: "O''Brien",
|
||||
},
|
||||
{
|
||||
name: "Select with dangerous keywords",
|
||||
input: "name, email; DROP TABLE users",
|
||||
mode: "select",
|
||||
expected: "name, email TABLE users",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ValidSQL(tt.input, tt.mode)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsNumeric tests the IsNumeric function
|
||||
func TestIsNumeric(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"123", true},
|
||||
{"123.45", true},
|
||||
{"-123", true},
|
||||
{"-123.45", true},
|
||||
{"0", true},
|
||||
{"abc", false},
|
||||
{"12.34.56", false},
|
||||
{"", false},
|
||||
{"123abc", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := IsNumeric(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlQryWhere tests the WHERE clause manipulation
|
||||
func TestSqlQryWhere(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
condition string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Add WHERE to query without WHERE",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
condition: "status = 'active'",
|
||||
expected: "SELECT * FROM users WHERE status = 'active' ",
|
||||
},
|
||||
{
|
||||
name: "Add AND to query with existing WHERE",
|
||||
sqlQuery: "SELECT * FROM users WHERE id > 0",
|
||||
condition: "status = 'active'",
|
||||
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
|
||||
},
|
||||
{
|
||||
name: "Add WHERE before ORDER BY",
|
||||
sqlQuery: "SELECT * FROM users ORDER BY name",
|
||||
condition: "status = 'active'",
|
||||
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
|
||||
},
|
||||
{
|
||||
name: "Add WHERE before GROUP BY",
|
||||
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
|
||||
condition: "status = 'active'",
|
||||
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
|
||||
},
|
||||
{
|
||||
name: "Add WHERE before LIMIT",
|
||||
sqlQuery: "SELECT * FROM users LIMIT 10",
|
||||
condition: "status = 'active'",
|
||||
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sqlQryWhere(tt.sqlQuery, tt.condition)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetIPAddress tests IP address extraction
|
||||
func TestGetIPAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setupReq func() *http.Request
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For header",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
|
||||
return req
|
||||
},
|
||||
expected: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP header",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.200")
|
||||
return req
|
||||
},
|
||||
expected: "192.168.1.200",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr fallback",
|
||||
setupReq: func() *http.Request {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
return req
|
||||
},
|
||||
expected: "192.168.1.1:12345",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := tt.setupReq()
|
||||
result := getIPAddress(req)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParsePaginationParams tests pagination parameter parsing
|
||||
func TestParsePaginationParams(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams map[string]string
|
||||
expectedSort string
|
||||
expectedLimit int
|
||||
expectedOffset int
|
||||
}{
|
||||
{
|
||||
name: "No parameters - defaults",
|
||||
queryParams: map[string]string{},
|
||||
expectedSort: "",
|
||||
expectedLimit: 20,
|
||||
expectedOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "All parameters provided",
|
||||
queryParams: map[string]string{
|
||||
"sort": "name,-created_at",
|
||||
"limit": "100",
|
||||
"offset": "50",
|
||||
},
|
||||
expectedSort: "name,-created_at",
|
||||
expectedLimit: 100,
|
||||
expectedOffset: 50,
|
||||
},
|
||||
{
|
||||
name: "Invalid limit - use default",
|
||||
queryParams: map[string]string{
|
||||
"limit": "invalid",
|
||||
},
|
||||
expectedSort: "",
|
||||
expectedLimit: 20,
|
||||
expectedOffset: 0,
|
||||
},
|
||||
{
|
||||
name: "Negative offset - use default",
|
||||
queryParams: map[string]string{
|
||||
"offset": "-10",
|
||||
},
|
||||
expectedSort: "",
|
||||
expectedLimit: 20,
|
||||
expectedOffset: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||
sort, limit, offset := handler.parsePaginationParams(req)
|
||||
|
||||
if sort != tt.expectedSort {
|
||||
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
|
||||
}
|
||||
if limit != tt.expectedLimit {
|
||||
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
|
||||
}
|
||||
if offset != tt.expectedOffset {
|
||||
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlQuery tests the SqlQuery handler for single record queries
|
||||
func TestSqlQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
blankParams bool
|
||||
queryParams map[string]string
|
||||
headers map[string]string
|
||||
setupDB func() *MockDatabase
|
||||
expectedStatus int
|
||||
validateResp func(t *testing.T, body []byte)
|
||||
}{
|
||||
{
|
||||
name: "Basic query - returns single record",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = 1",
|
||||
blankParams: false,
|
||||
setupDB: func() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
db := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
rows := dest.(*[]map[string]interface{})
|
||||
*rows = []map[string]interface{}{
|
||||
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(db)
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedStatus: 200,
|
||||
validateResp: func(t *testing.T, body []byte) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
if result["name"] != "Test User" {
|
||||
t.Errorf("Expected name='Test User', got %v", result["name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Query with no results",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = 999",
|
||||
blankParams: false,
|
||||
setupDB: func() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
db := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
// Return empty array
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(db)
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedStatus: 200,
|
||||
validateResp: func(t *testing.T, body []byte) {
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected empty result, got %v", result)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.setupDB()
|
||||
handler := NewHandler(db)
|
||||
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery(tt.sqlQuery, tt.blankParams)
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
if tt.validateResp != nil {
|
||||
tt.validateResp(t, w.Body.Bytes())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlQueryList tests the SqlQueryList handler for list queries
|
||||
func TestSqlQueryList(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
noCount bool
|
||||
blankParams bool
|
||||
allowFilter bool
|
||||
queryParams map[string]string
|
||||
headers map[string]string
|
||||
setupDB func() *MockDatabase
|
||||
expectedStatus int
|
||||
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
name: "Basic list query",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
noCount: false,
|
||||
blankParams: false,
|
||||
allowFilter: false,
|
||||
setupDB: func() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
callCount := 0
|
||||
db := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
callCount++
|
||||
if strings.Contains(query, "COUNT") {
|
||||
// Count query
|
||||
countResult := dest.(*struct{ Count int64 })
|
||||
countResult.Count = 2
|
||||
} else {
|
||||
// Main query
|
||||
rows := dest.(*[]map[string]interface{})
|
||||
*rows = []map[string]interface{}{
|
||||
{"id": float64(1), "name": "User 1"},
|
||||
{"id": float64(2), "name": "User 2"},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(db)
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedStatus: 200,
|
||||
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
if len(result) != 2 {
|
||||
t.Errorf("Expected 2 results, got %d", len(result))
|
||||
}
|
||||
|
||||
// Check Content-Range header
|
||||
contentRange := w.Header().Get("Content-Range")
|
||||
if !strings.Contains(contentRange, "2") {
|
||||
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "List query with noCount",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
noCount: true,
|
||||
blankParams: false,
|
||||
allowFilter: false,
|
||||
setupDB: func() *MockDatabase {
|
||||
return &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
db := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
if strings.Contains(query, "COUNT") {
|
||||
t.Error("Count query should not be executed when noCount is true")
|
||||
}
|
||||
rows := dest.(*[]map[string]interface{})
|
||||
*rows = []map[string]interface{}{
|
||||
{"id": float64(1), "name": "User 1"},
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(db)
|
||||
},
|
||||
}
|
||||
},
|
||||
expectedStatus: 200,
|
||||
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
|
||||
var result []map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
|
||||
t.Fatalf("Failed to unmarshal response: %v", err)
|
||||
}
|
||||
if len(result) != 1 {
|
||||
t.Errorf("Expected 1 result, got %d", len(result))
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := tt.setupDB()
|
||||
handler := NewHandler(db)
|
||||
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQueryList(tt.sqlQuery, tt.noCount, tt.blankParams, tt.allowFilter)
|
||||
handlerFunc(w, req)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
if tt.validateResp != nil {
|
||||
tt.validateResp(t, w)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeQueryParams tests query parameter merging
|
||||
func TestMergeQueryParams(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
queryParams map[string]string
|
||||
allowFilter bool
|
||||
expectedQuery string
|
||||
checkVars func(t *testing.T, vars map[string]interface{})
|
||||
}{
|
||||
{
|
||||
name: "Replace placeholder with parameter",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||
queryParams: map[string]string{"p-user_id": "123"},
|
||||
allowFilter: false,
|
||||
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||
if vars["p-user_id"] != "123" {
|
||||
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Add filter when allowed",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
queryParams: map[string]string{"status": "active"},
|
||||
allowFilter: true,
|
||||
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||
if vars["status"] != "active" {
|
||||
t.Errorf("Expected status=active, got %v", vars["status"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||
variables := make(map[string]interface{})
|
||||
propQry := make(map[string]string)
|
||||
|
||||
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
|
||||
|
||||
if result == "" {
|
||||
t.Error("Expected non-empty SQL query result")
|
||||
}
|
||||
|
||||
if tt.checkVars != nil {
|
||||
tt.checkVars(t, variables)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeHeaderParams tests header parameter merging
|
||||
func TestMergeHeaderParams(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
headers map[string]string
|
||||
expectedQuery string
|
||||
checkVars func(t *testing.T, vars map[string]interface{})
|
||||
}{
|
||||
{
|
||||
name: "Field filter header",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
headers: map[string]string{"X-FieldFilter-Status": "1"},
|
||||
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||
if vars["x-fieldfilter-status"] != "1" {
|
||||
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Search filter header",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
headers: map[string]string{"X-SearchFilter-Name": "john"},
|
||||
checkVars: func(t *testing.T, vars map[string]interface{}) {
|
||||
if vars["x-searchfilter-name"] != "john" {
|
||||
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
|
||||
variables := make(map[string]interface{})
|
||||
propQry := make(map[string]string)
|
||||
complexAPI := false
|
||||
|
||||
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
|
||||
|
||||
if result == "" {
|
||||
t.Error("Expected non-empty SQL query result")
|
||||
}
|
||||
|
||||
if tt.checkVars != nil {
|
||||
tt.checkVars(t, variables)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestReplaceMetaVariables tests meta variable replacement
|
||||
func TestReplaceMetaVariables(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
userCtx := &security.UserContext{
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
SessionID: "456",
|
||||
}
|
||||
|
||||
metainfo := map[string]interface{}{
|
||||
"ipaddress": "192.168.1.1",
|
||||
"url": "/api/test",
|
||||
}
|
||||
|
||||
variables := map[string]interface{}{
|
||||
"param1": "value1",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
expectedCheck func(result string) bool
|
||||
}{
|
||||
{
|
||||
name: "Replace [rid_user]",
|
||||
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "123")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Replace [user]",
|
||||
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "'testuser'")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Replace [rid_session]",
|
||||
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
|
||||
expectedCheck: func(result string) bool {
|
||||
return strings.Contains(result, "456")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
|
||||
|
||||
if !tt.expectedCheck(result) {
|
||||
t.Errorf("Meta variable replacement failed. Query: %s", result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
|
||||
func TestGetReplacementForBlankParam(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
param string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Parameter in single quotes",
|
||||
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
|
||||
param: "[username]",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Parameter in dollar quotes",
|
||||
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
|
||||
param: "[jsondata]",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Parameter not in quotes",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
|
||||
param: "[user_id]",
|
||||
expected: "NULL",
|
||||
},
|
||||
{
|
||||
name: "Parameter not in quotes with AND",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
|
||||
param: "[user_id]",
|
||||
expected: "NULL",
|
||||
},
|
||||
{
|
||||
name: "Parameter in mixed quote context - before quote",
|
||||
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
|
||||
param: "[user_id]",
|
||||
expected: "NULL",
|
||||
},
|
||||
{
|
||||
name: "Parameter in mixed quote context - in quotes",
|
||||
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
|
||||
param: "[username]",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Parameter with dollar quote tag",
|
||||
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
|
||||
param: "[content]",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
160
pkg/funcspec/hooks.go
Normal file
160
pkg/funcspec/hooks.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// Query operation hooks (for SqlQuery - single record)
|
||||
BeforeQuery HookType = "before_query"
|
||||
AfterQuery HookType = "after_query"
|
||||
|
||||
// Query list operation hooks (for SqlQueryList - multiple records)
|
||||
BeforeQueryList HookType = "before_query_list"
|
||||
AfterQueryList HookType = "after_query_list"
|
||||
|
||||
// SQL execution hooks (just before SQL is executed)
|
||||
BeforeSQLExec HookType = "before_sql_exec"
|
||||
AfterSQLExec HookType = "after_sql_exec"
|
||||
|
||||
// Response hooks (before response is sent)
|
||||
BeforeResponse HookType = "before_response"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler // Reference to the handler for accessing database
|
||||
Request *http.Request
|
||||
Writer http.ResponseWriter
|
||||
|
||||
// SQL query and variables
|
||||
SQLQuery string // The SQL query being executed (can be modified by hooks)
|
||||
Variables map[string]interface{} // Variables extracted from request
|
||||
InputVars []string // Input variable placeholders found in query
|
||||
MetaInfo map[string]interface{} // Metadata about the request
|
||||
PropQry map[string]string // Property query parameters
|
||||
|
||||
// User context
|
||||
UserContext *security.UserContext
|
||||
|
||||
// Pagination and filtering (for list queries)
|
||||
SortColumns string
|
||||
Limit int
|
||||
Offset int
|
||||
|
||||
// Results
|
||||
Result interface{} // Query result (single record or list)
|
||||
Total int64 // Total count (for list queries)
|
||||
Error error // Error if operation failed
|
||||
ComplexAPI bool // Whether complex API response format is requested
|
||||
NoCount bool // Whether count query should be skipped
|
||||
BlankParams bool // Whether blank parameters should be removed
|
||||
AllowFilter bool // Whether filtering is allowed
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
// It receives a HookContext and can modify it or return an error
|
||||
// If an error is returned, the operation will be aborted
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
// NewHookRegistry creates a new hook registry
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a new hook for the specified hook type
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
// RegisterMultiple registers a hook for multiple hook types
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs all hooks for the specified type in order
|
||||
// If any hook returns an error, execution stops and the error is returned
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all hooks for the specified type
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
logger.Info("Cleared all funcspec hooks for %s", hookType)
|
||||
}
|
||||
|
||||
// ClearAll removes all registered hooks
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
logger.Info("Cleared all funcspec hooks")
|
||||
}
|
||||
|
||||
// Count returns the number of hooks registered for a specific type
|
||||
func (r *HookRegistry) Count(hookType HookType) int {
|
||||
if hooks, exists := r.hooks[hookType]; exists {
|
||||
return len(hooks)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// HasHooks returns true if there are any hooks registered for the specified type
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
return r.Count(hookType) > 0
|
||||
}
|
||||
|
||||
// GetAllHookTypes returns all hook types that have registered hooks
|
||||
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||
types := make([]HookType, 0, len(r.hooks))
|
||||
for hookType := range r.hooks {
|
||||
types = append(types, hookType)
|
||||
}
|
||||
return types
|
||||
}
|
||||
137
pkg/funcspec/hooks_example.go
Normal file
137
pkg/funcspec/hooks_example.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Example hook functions demonstrating various use cases
|
||||
|
||||
// ExampleLoggingHook logs all SQL queries before execution
|
||||
func ExampleLoggingHook(ctx *HookContext) error {
|
||||
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleSecurityHook validates user permissions before executing queries
|
||||
func ExampleSecurityHook(ctx *HookContext) error {
|
||||
// Example: Block queries that try to access sensitive tables
|
||||
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
|
||||
if ctx.UserContext.UserID != 1 { // Only admin can access
|
||||
ctx.Abort = true
|
||||
ctx.AbortCode = 403
|
||||
ctx.AbortMessage = "Access denied: insufficient permissions"
|
||||
return fmt.Errorf("access denied to sensitive_table")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
|
||||
func ExampleQueryModificationHook(ctx *HookContext) error {
|
||||
// Example: Automatically add user_id filter for non-admin users
|
||||
if ctx.UserContext.UserID != 1 { // Not admin
|
||||
// Add WHERE clause to filter by user_id
|
||||
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
|
||||
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
|
||||
} else {
|
||||
ctx.SQLQuery = strings.Replace(
|
||||
ctx.SQLQuery,
|
||||
"WHERE",
|
||||
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
|
||||
1,
|
||||
)
|
||||
}
|
||||
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleResultFilterHook filters results after query execution
|
||||
func ExampleResultFilterHook(ctx *HookContext) error {
|
||||
// Example: Remove sensitive fields from results for non-admin users
|
||||
if ctx.UserContext.UserID != 1 { // Not admin
|
||||
switch result := ctx.Result.(type) {
|
||||
case []map[string]interface{}:
|
||||
// Filter list results
|
||||
for i := range result {
|
||||
delete(result[i], "password")
|
||||
delete(result[i], "ssn")
|
||||
delete(result[i], "credit_card")
|
||||
}
|
||||
case map[string]interface{}:
|
||||
// Filter single record
|
||||
delete(result, "password")
|
||||
delete(result, "ssn")
|
||||
delete(result, "credit_card")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleAuditHook logs all queries and results for audit purposes
|
||||
func ExampleAuditHook(ctx *HookContext) error {
|
||||
// Log to audit table or external system
|
||||
logger.Info("AUDIT: User %s (%d) executed query from %s",
|
||||
ctx.UserContext.UserName,
|
||||
ctx.UserContext.UserID,
|
||||
ctx.Request.RemoteAddr,
|
||||
)
|
||||
|
||||
// In a real implementation, you might:
|
||||
// - Insert into an audit log table
|
||||
// - Send to a logging service
|
||||
// - Write to a file
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleCacheHook implements simple response caching
|
||||
func ExampleCacheHook(ctx *HookContext) error {
|
||||
// This is a simplified example - real caching would use a proper cache store
|
||||
// Check if we have a cached result for this query
|
||||
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
|
||||
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
|
||||
// ctx.Result = cachedResult
|
||||
// ctx.Abort = true // Skip query execution
|
||||
// ctx.AbortMessage = "Serving from cache"
|
||||
// }
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExampleErrorHandlingHook provides custom error handling
|
||||
func ExampleErrorHandlingHook(ctx *HookContext) error {
|
||||
if ctx.Error != nil {
|
||||
// Log error with context
|
||||
logger.Error("Query failed for user %s: %v\nQuery: %s",
|
||||
ctx.UserContext.UserName,
|
||||
ctx.Error,
|
||||
ctx.SQLQuery,
|
||||
)
|
||||
|
||||
// You could send notifications, update metrics, etc.
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Example of registering hooks:
|
||||
//
|
||||
// func SetupHooks(handler *Handler) {
|
||||
// hooks := handler.Hooks()
|
||||
//
|
||||
// // Register security hook before query execution
|
||||
// hooks.Register(BeforeQuery, ExampleSecurityHook)
|
||||
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
|
||||
//
|
||||
// // Register logging hook before SQL execution
|
||||
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
|
||||
//
|
||||
// // Register result filtering after query
|
||||
// hooks.Register(AfterQuery, ExampleResultFilterHook)
|
||||
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
|
||||
//
|
||||
// // Register audit hook after execution
|
||||
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
|
||||
// }
|
||||
589
pkg/funcspec/hooks_test.go
Normal file
589
pkg/funcspec/hooks_test.go
Normal file
@@ -0,0 +1,589 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// TestNewHookRegistry tests hook registry creation
|
||||
func TestNewHookRegistry(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
if registry == nil {
|
||||
t.Fatal("Expected registry to be created, got nil")
|
||||
}
|
||||
|
||||
if registry.hooks == nil {
|
||||
t.Error("Expected hooks map to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterHook tests registering a single hook
|
||||
func TestRegisterHook(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hookCalled := false
|
||||
testHook := func(ctx *HookContext) error {
|
||||
hookCalled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, testHook)
|
||||
|
||||
if !registry.HasHooks(BeforeQuery) {
|
||||
t.Error("Expected hook to be registered")
|
||||
}
|
||||
|
||||
if registry.Count(BeforeQuery) != 1 {
|
||||
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
|
||||
}
|
||||
|
||||
// Execute the hook
|
||||
ctx := &HookContext{}
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if !hookCalled {
|
||||
t.Error("Expected hook to be called")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterMultipleHooks tests registering multiple hooks for same type
|
||||
func TestRegisterMultipleHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
callOrder := []int{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 3)
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, hook1)
|
||||
registry.Register(BeforeQuery, hook2)
|
||||
registry.Register(BeforeQuery, hook3)
|
||||
|
||||
if registry.Count(BeforeQuery) != 3 {
|
||||
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
|
||||
}
|
||||
|
||||
// Execute hooks
|
||||
ctx := &HookContext{}
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify hooks were called in order
|
||||
if len(callOrder) != 3 {
|
||||
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
|
||||
}
|
||||
|
||||
for i, expected := range []int{1, 2, 3} {
|
||||
if callOrder[i] != expected {
|
||||
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
|
||||
func TestRegisterMultipleHookTypes(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
callCount := 0
|
||||
testHook := func(ctx *HookContext) error {
|
||||
callCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
|
||||
registry.RegisterMultiple(hookTypes, testHook)
|
||||
|
||||
// Verify hook is registered for all types
|
||||
for _, hookType := range hookTypes {
|
||||
if !registry.HasHooks(hookType) {
|
||||
t.Errorf("Expected hook to be registered for %s", hookType)
|
||||
}
|
||||
|
||||
if registry.Count(hookType) != 1 {
|
||||
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
|
||||
}
|
||||
}
|
||||
|
||||
// Execute each hook type
|
||||
ctx := &HookContext{}
|
||||
for _, hookType := range hookTypes {
|
||||
if err := registry.Execute(hookType, ctx); err != nil {
|
||||
t.Errorf("Hook execution failed for %s: %v", hookType, err)
|
||||
}
|
||||
}
|
||||
|
||||
if callCount != 3 {
|
||||
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookError tests hook error handling
|
||||
func TestHookError(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
expectedError := fmt.Errorf("test error")
|
||||
errorHook := func(ctx *HookContext) error {
|
||||
return expectedError
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, errorHook)
|
||||
|
||||
ctx := &HookContext{}
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error from hook, got nil")
|
||||
}
|
||||
|
||||
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
|
||||
t.Errorf("Expected error message to contain hook error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookAbort tests hook abort functionality
|
||||
func TestHookAbort(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
abortHook := func(ctx *HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "Operation aborted by hook"
|
||||
ctx.AbortCode = 403
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, abortHook)
|
||||
|
||||
ctx := &HookContext{}
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error when hook aborts, got nil")
|
||||
}
|
||||
|
||||
if !ctx.Abort {
|
||||
t.Error("Expected Abort to be true")
|
||||
}
|
||||
|
||||
if ctx.AbortMessage != "Operation aborted by hook" {
|
||||
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
|
||||
}
|
||||
|
||||
if ctx.AbortCode != 403 {
|
||||
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookChainWithError tests that hook chain stops on first error
|
||||
func TestHookChainWithError(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
callOrder := []int{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 2)
|
||||
return fmt.Errorf("error in hook 2")
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
callOrder = append(callOrder, 3)
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, hook1)
|
||||
registry.Register(BeforeQuery, hook2)
|
||||
registry.Register(BeforeQuery, hook3)
|
||||
|
||||
ctx := &HookContext{}
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error from hook chain")
|
||||
}
|
||||
|
||||
// Only first two hooks should have been called
|
||||
if len(callOrder) != 2 {
|
||||
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
|
||||
}
|
||||
|
||||
if callOrder[0] != 1 || callOrder[1] != 2 {
|
||||
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearHooks tests clearing hooks
|
||||
func TestClearHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
testHook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, testHook)
|
||||
registry.Register(AfterQuery, testHook)
|
||||
|
||||
if !registry.HasHooks(BeforeQuery) {
|
||||
t.Error("Expected BeforeQuery hook to be registered")
|
||||
}
|
||||
|
||||
registry.Clear(BeforeQuery)
|
||||
|
||||
if registry.HasHooks(BeforeQuery) {
|
||||
t.Error("Expected BeforeQuery hooks to be cleared")
|
||||
}
|
||||
|
||||
if !registry.HasHooks(AfterQuery) {
|
||||
t.Error("Expected AfterQuery hook to still be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClearAllHooks tests clearing all hooks
|
||||
func TestClearAllHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
testHook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, testHook)
|
||||
registry.Register(AfterQuery, testHook)
|
||||
registry.Register(BeforeSQLExec, testHook)
|
||||
|
||||
registry.ClearAll()
|
||||
|
||||
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
|
||||
t.Error("Expected all hooks to be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetAllHookTypes tests getting all registered hook types
|
||||
func TestGetAllHookTypes(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
testHook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, testHook)
|
||||
registry.Register(AfterQuery, testHook)
|
||||
|
||||
types := registry.GetAllHookTypes()
|
||||
|
||||
if len(types) != 2 {
|
||||
t.Errorf("Expected 2 hook types, got %d", len(types))
|
||||
}
|
||||
|
||||
// Verify the types are present
|
||||
foundBefore := false
|
||||
foundAfter := false
|
||||
for _, hookType := range types {
|
||||
if hookType == BeforeQuery {
|
||||
foundBefore = true
|
||||
}
|
||||
if hookType == AfterQuery {
|
||||
foundAfter = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundBefore || !foundAfter {
|
||||
t.Error("Expected both BeforeQuery and AfterQuery hook types")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookContextModification tests that hooks can modify the context
|
||||
func TestHookContextModification(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
// Hook that modifies SQL query
|
||||
modifyHook := func(ctx *HookContext) error {
|
||||
ctx.SQLQuery = "SELECT * FROM modified_table"
|
||||
ctx.Variables["new_var"] = "new_value"
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeQuery, modifyHook)
|
||||
|
||||
ctx := &HookContext{
|
||||
SQLQuery: "SELECT * FROM original_table",
|
||||
Variables: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeQuery, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if ctx.SQLQuery != "SELECT * FROM modified_table" {
|
||||
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
|
||||
}
|
||||
|
||||
if ctx.Variables["new_var"] != "new_value" {
|
||||
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
|
||||
}
|
||||
}
|
||||
|
||||
// TestExampleHooks tests the example hooks
|
||||
func TestExampleLoggingHook(t *testing.T) {
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
SQLQuery: "SELECT * FROM test",
|
||||
UserContext: &security.UserContext{
|
||||
UserName: "testuser",
|
||||
},
|
||||
}
|
||||
|
||||
err := ExampleLoggingHook(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("ExampleLoggingHook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleSecurityHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
userID int
|
||||
shouldAbort bool
|
||||
}{
|
||||
{
|
||||
name: "Admin accessing sensitive table",
|
||||
sqlQuery: "SELECT * FROM sensitive_table",
|
||||
userID: 1,
|
||||
shouldAbort: false,
|
||||
},
|
||||
{
|
||||
name: "Non-admin accessing sensitive table",
|
||||
sqlQuery: "SELECT * FROM sensitive_table",
|
||||
userID: 2,
|
||||
shouldAbort: true,
|
||||
},
|
||||
{
|
||||
name: "Non-admin accessing normal table",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
userID: 2,
|
||||
shouldAbort: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
SQLQuery: tt.sqlQuery,
|
||||
UserContext: &security.UserContext{
|
||||
UserID: tt.userID,
|
||||
},
|
||||
}
|
||||
|
||||
_ = ExampleSecurityHook(ctx)
|
||||
|
||||
if tt.shouldAbort {
|
||||
if !ctx.Abort {
|
||||
t.Error("Expected security hook to abort operation")
|
||||
}
|
||||
if ctx.AbortCode != 403 {
|
||||
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
|
||||
}
|
||||
} else {
|
||||
if ctx.Abort {
|
||||
t.Error("Expected security hook not to abort operation")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleResultFilterHook(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int
|
||||
result interface{}
|
||||
validate func(t *testing.T, result interface{})
|
||||
}{
|
||||
{
|
||||
name: "Admin user - no filtering",
|
||||
userID: 1,
|
||||
result: map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "Test",
|
||||
"password": "secret",
|
||||
},
|
||||
validate: func(t *testing.T, result interface{}) {
|
||||
m := result.(map[string]interface{})
|
||||
if _, exists := m["password"]; !exists {
|
||||
t.Error("Expected password field to remain for admin")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Regular user - sensitive fields removed",
|
||||
userID: 2,
|
||||
result: map[string]interface{}{
|
||||
"id": 1,
|
||||
"name": "Test",
|
||||
"password": "secret",
|
||||
"ssn": "123-45-6789",
|
||||
},
|
||||
validate: func(t *testing.T, result interface{}) {
|
||||
m := result.(map[string]interface{})
|
||||
if _, exists := m["password"]; exists {
|
||||
t.Error("Expected password field to be removed")
|
||||
}
|
||||
if _, exists := m["ssn"]; exists {
|
||||
t.Error("Expected ssn field to be removed")
|
||||
}
|
||||
if _, exists := m["name"]; !exists {
|
||||
t.Error("Expected name field to remain")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Regular user - list results filtered",
|
||||
userID: 2,
|
||||
result: []map[string]interface{}{
|
||||
{"id": 1, "name": "User 1", "password": "secret1"},
|
||||
{"id": 2, "name": "User 2", "password": "secret2"},
|
||||
},
|
||||
validate: func(t *testing.T, result interface{}) {
|
||||
list := result.([]map[string]interface{})
|
||||
for _, m := range list {
|
||||
if _, exists := m["password"]; exists {
|
||||
t.Error("Expected password field to be removed from list")
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Result: tt.result,
|
||||
UserContext: &security.UserContext{
|
||||
UserID: tt.userID,
|
||||
},
|
||||
}
|
||||
|
||||
err := ExampleResultFilterHook(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook failed: %v", err)
|
||||
}
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, ctx.Result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleAuditHook(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/api/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Request: req,
|
||||
UserContext: &security.UserContext{
|
||||
UserID: 123,
|
||||
UserName: "testuser",
|
||||
},
|
||||
}
|
||||
|
||||
err := ExampleAuditHook(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("ExampleAuditHook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExampleErrorHandlingHook(t *testing.T) {
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
SQLQuery: "SELECT * FROM test",
|
||||
Error: fmt.Errorf("test error"),
|
||||
UserContext: &security.UserContext{
|
||||
UserName: "testuser",
|
||||
},
|
||||
}
|
||||
|
||||
err := ExampleErrorHandlingHook(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHookIntegrationWithHandler tests hooks integrated with the handler
|
||||
func TestHookIntegrationWithHandler(t *testing.T) {
|
||||
db := &MockDatabase{
|
||||
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
|
||||
queryDB := &MockDatabase{
|
||||
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
rows := dest.(*[]map[string]interface{})
|
||||
*rows = []map[string]interface{}{
|
||||
{"id": float64(1), "name": "Test User"},
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
return fn(queryDB)
|
||||
},
|
||||
}
|
||||
|
||||
handler := NewHandler(db)
|
||||
|
||||
// Register a hook that modifies the SQL query
|
||||
hookCalled := false
|
||||
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
|
||||
hookCalled = true
|
||||
// Verify we can access context data
|
||||
if ctx.SQLQuery == "" {
|
||||
t.Error("Expected SQL query to be set")
|
||||
}
|
||||
if ctx.UserContext == nil {
|
||||
t.Error("Expected user context to be set")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Execute a query
|
||||
req := createTestRequest("GET", "/test", nil, nil, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", false)
|
||||
handlerFunc(w, req)
|
||||
|
||||
if !hookCalled {
|
||||
t.Error("Expected hook to be called during query execution")
|
||||
}
|
||||
|
||||
if w.Code != 200 {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
411
pkg/funcspec/parameters.go
Normal file
411
pkg/funcspec/parameters.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
|
||||
)
|
||||
|
||||
// RequestParameters holds parsed parameters from headers and query string
|
||||
type RequestParameters struct {
|
||||
// Field selection
|
||||
SelectFields []string
|
||||
NotSelectFields []string
|
||||
Distinct bool
|
||||
|
||||
// Filtering
|
||||
FieldFilters map[string]string // column -> value (exact match)
|
||||
SearchFilters map[string]string // column -> value (ILIKE)
|
||||
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
|
||||
CustomSQLWhere string
|
||||
CustomSQLOr string
|
||||
|
||||
// Sorting & Pagination
|
||||
SortColumns string
|
||||
Limit int
|
||||
Offset int
|
||||
|
||||
// Advanced features
|
||||
SkipCount bool
|
||||
SkipCache bool
|
||||
|
||||
// Response format
|
||||
ResponseFormat string // "simple", "detail", "syncfusion"
|
||||
ComplexAPI bool // true if NOT simple API
|
||||
}
|
||||
|
||||
// FilterOperator represents a filter with operator
|
||||
type FilterOperator struct {
|
||||
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
|
||||
Value string
|
||||
Logic string // AND or OR
|
||||
}
|
||||
|
||||
// ParseParameters parses all parameters from request headers and query string
|
||||
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
|
||||
params := &RequestParameters{
|
||||
FieldFilters: make(map[string]string),
|
||||
SearchFilters: make(map[string]string),
|
||||
SearchOps: make(map[string]FilterOperator),
|
||||
Limit: 20, // Default limit
|
||||
Offset: 0, // Default offset
|
||||
ResponseFormat: "simple", // Default format
|
||||
ComplexAPI: false, // Default to simple API
|
||||
}
|
||||
|
||||
// Merge headers and query parameters
|
||||
combined := make(map[string]string)
|
||||
|
||||
// Add all headers (normalize to lowercase)
|
||||
for key, values := range r.Header {
|
||||
if len(values) > 0 {
|
||||
combined[strings.ToLower(key)] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Add all query parameters (override headers)
|
||||
for key, values := range r.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
combined[strings.ToLower(key)] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Parse each parameter
|
||||
for key, value := range combined {
|
||||
// Decode value if base64 encoded
|
||||
decodedValue := h.decodeValue(value)
|
||||
|
||||
switch {
|
||||
// Field Selection
|
||||
case strings.HasPrefix(key, "x-select-fields"):
|
||||
params.SelectFields = h.parseCommaSeparated(decodedValue)
|
||||
case strings.HasPrefix(key, "x-not-select-fields"):
|
||||
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
|
||||
case strings.HasPrefix(key, "x-distinct"):
|
||||
params.Distinct = strings.EqualFold(decodedValue, "true")
|
||||
|
||||
// Filtering
|
||||
case strings.HasPrefix(key, "x-fieldfilter-"):
|
||||
colName := strings.TrimPrefix(key, "x-fieldfilter-")
|
||||
params.FieldFilters[colName] = decodedValue
|
||||
case strings.HasPrefix(key, "x-searchfilter-"):
|
||||
colName := strings.TrimPrefix(key, "x-searchfilter-")
|
||||
params.SearchFilters[colName] = decodedValue
|
||||
case strings.HasPrefix(key, "x-searchop-"):
|
||||
h.parseSearchOp(params, key, decodedValue, "AND")
|
||||
case strings.HasPrefix(key, "x-searchor-"):
|
||||
h.parseSearchOp(params, key, decodedValue, "OR")
|
||||
case strings.HasPrefix(key, "x-searchand-"):
|
||||
h.parseSearchOp(params, key, decodedValue, "AND")
|
||||
case strings.HasPrefix(key, "x-custom-sql-w"):
|
||||
if params.CustomSQLWhere != "" {
|
||||
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
|
||||
} else {
|
||||
params.CustomSQLWhere = decodedValue
|
||||
}
|
||||
case strings.HasPrefix(key, "x-custom-sql-or"):
|
||||
if params.CustomSQLOr != "" {
|
||||
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
|
||||
} else {
|
||||
params.CustomSQLOr = decodedValue
|
||||
}
|
||||
|
||||
// Sorting & Pagination
|
||||
case key == "sort" || strings.HasPrefix(key, "x-sort"):
|
||||
params.SortColumns = decodedValue
|
||||
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
|
||||
// Handle sort(col1,-col2) syntax
|
||||
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||
params.SortColumns = sortValue
|
||||
case key == "limit" || strings.HasPrefix(key, "x-limit"):
|
||||
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
|
||||
params.Limit = limit
|
||||
}
|
||||
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
|
||||
// Handle limit(offset,limit) or limit(limit) syntax
|
||||
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
|
||||
parts := strings.Split(limitValue, ",")
|
||||
if len(parts) > 1 {
|
||||
if offset, err := strconv.Atoi(parts[0]); err == nil {
|
||||
params.Offset = offset
|
||||
}
|
||||
if limit, err := strconv.Atoi(parts[1]); err == nil {
|
||||
params.Limit = limit
|
||||
}
|
||||
} else {
|
||||
if limit, err := strconv.Atoi(parts[0]); err == nil {
|
||||
params.Limit = limit
|
||||
}
|
||||
}
|
||||
case key == "offset" || strings.HasPrefix(key, "x-offset"):
|
||||
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
|
||||
params.Offset = offset
|
||||
}
|
||||
|
||||
// Advanced features
|
||||
case strings.HasPrefix(key, "x-skipcount"):
|
||||
params.SkipCount = strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(key, "x-skipcache"):
|
||||
params.SkipCache = strings.EqualFold(decodedValue, "true")
|
||||
|
||||
// Response Format
|
||||
case strings.HasPrefix(key, "x-simpleapi"):
|
||||
params.ResponseFormat = "simple"
|
||||
params.ComplexAPI = decodedValue != "1" && !strings.EqualFold(decodedValue, "true")
|
||||
case strings.HasPrefix(key, "x-detailapi"):
|
||||
params.ResponseFormat = "detail"
|
||||
params.ComplexAPI = true
|
||||
case strings.HasPrefix(key, "x-syncfusion"):
|
||||
params.ResponseFormat = "syncfusion"
|
||||
params.ComplexAPI = true
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
|
||||
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
|
||||
var prefix string
|
||||
if logic == "OR" {
|
||||
prefix = "x-searchor-"
|
||||
} else {
|
||||
prefix = "x-searchop-"
|
||||
if strings.HasPrefix(headerKey, "x-searchand-") {
|
||||
prefix = "x-searchand-"
|
||||
}
|
||||
}
|
||||
|
||||
rest := strings.TrimPrefix(headerKey, prefix)
|
||||
parts := strings.SplitN(rest, "-", 2)
|
||||
if len(parts) != 2 {
|
||||
logger.Warn("Invalid search operator header format: %s", headerKey)
|
||||
return
|
||||
}
|
||||
|
||||
operator := parts[0]
|
||||
colName := parts[1]
|
||||
|
||||
params.SearchOps[colName] = FilterOperator{
|
||||
Operator: operator,
|
||||
Value: value,
|
||||
Logic: logic,
|
||||
}
|
||||
|
||||
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
|
||||
}
|
||||
|
||||
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
|
||||
func (h *Handler) decodeValue(value string) string {
|
||||
decoded, _ := restheadspec.DecodeParam(value)
|
||||
return decoded
|
||||
}
|
||||
|
||||
// parseCommaSeparated parses comma-separated values
|
||||
func (h *Handler) parseCommaSeparated(value string) []string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(value, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part != "" {
|
||||
result = append(result, part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// ApplyFieldSelection applies column selection to SQL query
|
||||
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
|
||||
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
|
||||
return sqlQuery
|
||||
}
|
||||
|
||||
// This is a simplified implementation
|
||||
// A full implementation would parse the SQL and replace the SELECT clause
|
||||
// For now, we log a warning that this feature needs manual implementation
|
||||
if len(params.SelectFields) > 0 {
|
||||
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
|
||||
}
|
||||
if len(params.NotSelectFields) > 0 {
|
||||
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
|
||||
}
|
||||
|
||||
return sqlQuery
|
||||
}
|
||||
|
||||
// ApplyFilters applies all filters to the SQL query
|
||||
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
|
||||
// Apply field filters (exact match)
|
||||
for colName, value := range params.FieldFilters {
|
||||
condition := ""
|
||||
if value == "" || value == "0" {
|
||||
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
|
||||
} else {
|
||||
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
|
||||
}
|
||||
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||
logger.Debug("Applied field filter: %s", condition)
|
||||
}
|
||||
|
||||
// Apply search filters (ILIKE)
|
||||
for colName, value := range params.SearchFilters {
|
||||
sval := strings.ReplaceAll(value, "'", "")
|
||||
if sval != "" {
|
||||
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
||||
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||
logger.Debug("Applied search filter: %s", condition)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply search operators
|
||||
for colName, filterOp := range params.SearchOps {
|
||||
condition := h.buildFilterCondition(colName, filterOp)
|
||||
if condition != "" {
|
||||
if filterOp.Logic == "OR" {
|
||||
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
|
||||
} else {
|
||||
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||
}
|
||||
logger.Debug("Applied search operator: %s", condition)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL WHERE
|
||||
if params.CustomSQLWhere != "" {
|
||||
colval := ValidSQL(params.CustomSQLWhere, "select")
|
||||
if colval != "" {
|
||||
sqlQuery = sqlQryWhere(sqlQuery, colval)
|
||||
logger.Debug("Applied custom SQL WHERE: %s", colval)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL OR
|
||||
if params.CustomSQLOr != "" {
|
||||
colval := ValidSQL(params.CustomSQLOr, "select")
|
||||
if colval != "" {
|
||||
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
|
||||
logger.Debug("Applied custom SQL OR: %s", colval)
|
||||
}
|
||||
}
|
||||
|
||||
return sqlQuery
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a SQL condition from a FilterOperator
|
||||
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
|
||||
safCol := ValidSQL(colName, "colname")
|
||||
operator := strings.ToLower(op.Operator)
|
||||
value := op.Value
|
||||
|
||||
switch operator {
|
||||
case "contains", "contain", "like":
|
||||
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "beginswith", "startswith":
|
||||
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "endswith":
|
||||
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "equals", "eq", "=":
|
||||
if IsNumeric(value) {
|
||||
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
|
||||
}
|
||||
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "notequals", "neq", "ne", "!=", "<>":
|
||||
if IsNumeric(value) {
|
||||
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
|
||||
}
|
||||
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
case "greaterthan", "gt", ">":
|
||||
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
|
||||
case "lessthan", "lt", "<":
|
||||
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
|
||||
case "greaterthanorequal", "gte", "ge", ">=":
|
||||
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
|
||||
case "lessthanorequal", "lte", "le", "<=":
|
||||
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
|
||||
case "between":
|
||||
parts := strings.Split(value, ",")
|
||||
if len(parts) == 2 {
|
||||
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
|
||||
}
|
||||
case "betweeninclusive":
|
||||
parts := strings.Split(value, ",")
|
||||
if len(parts) == 2 {
|
||||
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
|
||||
}
|
||||
case "in":
|
||||
values := strings.Split(value, ",")
|
||||
safeValues := make([]string, len(values))
|
||||
for i, v := range values {
|
||||
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
|
||||
}
|
||||
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
|
||||
case "empty", "isnull", "null":
|
||||
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
|
||||
case "notempty", "isnotnull", "notnull":
|
||||
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
|
||||
default:
|
||||
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
|
||||
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ApplyDistinct adds DISTINCT to SQL query if requested
|
||||
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
|
||||
if !params.Distinct {
|
||||
return sqlQuery
|
||||
}
|
||||
|
||||
// Add DISTINCT after SELECT
|
||||
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
|
||||
if selectPos >= 0 {
|
||||
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
|
||||
afterSelect := sqlQuery[selectPos+6:]
|
||||
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
|
||||
logger.Debug("Applied DISTINCT to query")
|
||||
}
|
||||
|
||||
return sqlQuery
|
||||
}
|
||||
|
||||
// sqlQryWhereOr adds a WHERE clause with OR logic
|
||||
func sqlQryWhereOr(sqlquery, condition string) string {
|
||||
lowerQuery := strings.ToLower(sqlquery)
|
||||
wherePos := strings.Index(lowerQuery, " where ")
|
||||
groupPos := strings.Index(lowerQuery, " group by")
|
||||
orderPos := strings.Index(lowerQuery, " order by")
|
||||
limitPos := strings.Index(lowerQuery, " limit ")
|
||||
|
||||
// Find the insertion point
|
||||
insertPos := len(sqlquery)
|
||||
if groupPos > 0 && groupPos < insertPos {
|
||||
insertPos = groupPos
|
||||
}
|
||||
if orderPos > 0 && orderPos < insertPos {
|
||||
insertPos = orderPos
|
||||
}
|
||||
if limitPos > 0 && limitPos < insertPos {
|
||||
insertPos = limitPos
|
||||
}
|
||||
|
||||
if wherePos > 0 {
|
||||
// WHERE exists, add OR condition
|
||||
before := sqlquery[:insertPos]
|
||||
after := sqlquery[insertPos:]
|
||||
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
|
||||
} else {
|
||||
// No WHERE exists, add it
|
||||
before := sqlquery[:insertPos]
|
||||
after := sqlquery[insertPos:]
|
||||
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
|
||||
}
|
||||
}
|
||||
549
pkg/funcspec/parameters_test.go
Normal file
549
pkg/funcspec/parameters_test.go
Normal file
@@ -0,0 +1,549 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestParseParameters tests the comprehensive parameter parsing
|
||||
func TestParseParameters(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
queryParams map[string]string
|
||||
headers map[string]string
|
||||
validate func(t *testing.T, params *RequestParameters)
|
||||
}{
|
||||
{
|
||||
name: "Parse field selection",
|
||||
headers: map[string]string{
|
||||
"X-Select-Fields": "id,name,email",
|
||||
"X-Not-Select-Fields": "password,ssn",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if len(params.SelectFields) != 3 {
|
||||
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
|
||||
}
|
||||
if len(params.NotSelectFields) != 2 {
|
||||
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse distinct flag",
|
||||
headers: map[string]string{
|
||||
"X-Distinct": "true",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if !params.Distinct {
|
||||
t.Error("Expected Distinct to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse field filters",
|
||||
headers: map[string]string{
|
||||
"X-FieldFilter-Status": "active",
|
||||
"X-FieldFilter-Type": "admin",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if len(params.FieldFilters) != 2 {
|
||||
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
|
||||
}
|
||||
if params.FieldFilters["status"] != "active" {
|
||||
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse search filters",
|
||||
headers: map[string]string{
|
||||
"X-SearchFilter-Name": "john",
|
||||
"X-SearchFilter-Email": "test",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if len(params.SearchFilters) != 2 {
|
||||
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse sort columns",
|
||||
queryParams: map[string]string{
|
||||
"sort": "-created_at,name",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.SortColumns != "-created_at,name" {
|
||||
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse limit and offset",
|
||||
queryParams: map[string]string{
|
||||
"limit": "100",
|
||||
"offset": "50",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.Limit != 100 {
|
||||
t.Errorf("Expected limit=100, got %d", params.Limit)
|
||||
}
|
||||
if params.Offset != 50 {
|
||||
t.Errorf("Expected offset=50, got %d", params.Offset)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse skip count",
|
||||
headers: map[string]string{
|
||||
"X-SkipCount": "true",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if !params.SkipCount {
|
||||
t.Error("Expected SkipCount to be true")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse response format - syncfusion",
|
||||
headers: map[string]string{
|
||||
"X-Syncfusion": "true",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.ResponseFormat != "syncfusion" {
|
||||
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
|
||||
}
|
||||
if !params.ComplexAPI {
|
||||
t.Error("Expected ComplexAPI to be true for syncfusion format")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse response format - detail",
|
||||
headers: map[string]string{
|
||||
"X-DetailAPI": "true",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.ResponseFormat != "detail" {
|
||||
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse simple API",
|
||||
headers: map[string]string{
|
||||
"X-SimpleAPI": "true",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.ResponseFormat != "simple" {
|
||||
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
|
||||
}
|
||||
if params.ComplexAPI {
|
||||
t.Error("Expected ComplexAPI to be false for simple API")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL WHERE",
|
||||
headers: map[string]string{
|
||||
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if params.CustomSQLWhere == "" {
|
||||
t.Error("Expected CustomSQLWhere to be set")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse search operators - AND",
|
||||
headers: map[string]string{
|
||||
"X-SearchOp-Eq-Name": "john",
|
||||
"X-SearchOp-Gt-Age": "18",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if len(params.SearchOps) != 2 {
|
||||
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
|
||||
}
|
||||
if op, exists := params.SearchOps["name"]; exists {
|
||||
if op.Operator != "eq" {
|
||||
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
|
||||
}
|
||||
if op.Logic != "AND" {
|
||||
t.Errorf("Expected logic=AND, got %s", op.Logic)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected name search operator to exist")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse search operators - OR",
|
||||
headers: map[string]string{
|
||||
"X-SearchOr-Like-Description": "test",
|
||||
},
|
||||
validate: func(t *testing.T, params *RequestParameters) {
|
||||
if op, exists := params.SearchOps["description"]; exists {
|
||||
if op.Logic != "OR" {
|
||||
t.Errorf("Expected logic=OR, got %s", op.Logic)
|
||||
}
|
||||
} else {
|
||||
t.Error("Expected description search operator to exist")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
|
||||
params := handler.ParseParameters(req)
|
||||
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, params)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFilterCondition tests the filter condition builder
|
||||
func TestBuildFilterCondition(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
colName string
|
||||
operator FilterOperator
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Equals operator - numeric",
|
||||
colName: "age",
|
||||
operator: FilterOperator{
|
||||
Operator: "eq",
|
||||
Value: "25",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "age = 25",
|
||||
},
|
||||
{
|
||||
name: "Equals operator - string",
|
||||
colName: "name",
|
||||
operator: FilterOperator{
|
||||
Operator: "eq",
|
||||
Value: "john",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "name = 'john'",
|
||||
},
|
||||
{
|
||||
name: "Not equals operator",
|
||||
colName: "status",
|
||||
operator: FilterOperator{
|
||||
Operator: "neq",
|
||||
Value: "inactive",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "status != 'inactive'",
|
||||
},
|
||||
{
|
||||
name: "Greater than operator",
|
||||
colName: "age",
|
||||
operator: FilterOperator{
|
||||
Operator: "gt",
|
||||
Value: "18",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "age > 18",
|
||||
},
|
||||
{
|
||||
name: "Less than operator",
|
||||
colName: "price",
|
||||
operator: FilterOperator{
|
||||
Operator: "lt",
|
||||
Value: "100",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "price < 100",
|
||||
},
|
||||
{
|
||||
name: "Contains operator",
|
||||
colName: "description",
|
||||
operator: FilterOperator{
|
||||
Operator: "contains",
|
||||
Value: "test",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "description ILIKE '%test%'",
|
||||
},
|
||||
{
|
||||
name: "Starts with operator",
|
||||
colName: "name",
|
||||
operator: FilterOperator{
|
||||
Operator: "startswith",
|
||||
Value: "john",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "name ILIKE 'john%'",
|
||||
},
|
||||
{
|
||||
name: "Ends with operator",
|
||||
colName: "email",
|
||||
operator: FilterOperator{
|
||||
Operator: "endswith",
|
||||
Value: "@example.com",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "email ILIKE '%@example.com'",
|
||||
},
|
||||
{
|
||||
name: "Between operator",
|
||||
colName: "age",
|
||||
operator: FilterOperator{
|
||||
Operator: "between",
|
||||
Value: "18,65",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "age > 18 AND age < 65",
|
||||
},
|
||||
{
|
||||
name: "IN operator",
|
||||
colName: "status",
|
||||
operator: FilterOperator{
|
||||
Operator: "in",
|
||||
Value: "active,pending,approved",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "status IN ('active', 'pending', 'approved')",
|
||||
},
|
||||
{
|
||||
name: "IS NULL operator",
|
||||
colName: "deleted_at",
|
||||
operator: FilterOperator{
|
||||
Operator: "null",
|
||||
Value: "",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "(deleted_at IS NULL OR deleted_at = '')",
|
||||
},
|
||||
{
|
||||
name: "IS NOT NULL operator",
|
||||
colName: "created_at",
|
||||
operator: FilterOperator{
|
||||
Operator: "notnull",
|
||||
Value: "",
|
||||
Logic: "AND",
|
||||
},
|
||||
expected: "(created_at IS NOT NULL AND created_at != '')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.buildFilterCondition(tt.colName, tt.operator)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyFilters tests the filter application to SQL queries
|
||||
func TestApplyFilters(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
params *RequestParameters
|
||||
expectedSQL string
|
||||
shouldContain []string
|
||||
}{
|
||||
{
|
||||
name: "Apply field filter",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
params: &RequestParameters{
|
||||
FieldFilters: map[string]string{
|
||||
"status": "active",
|
||||
},
|
||||
},
|
||||
shouldContain: []string{"WHERE", "status"},
|
||||
},
|
||||
{
|
||||
name: "Apply search filter",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
params: &RequestParameters{
|
||||
SearchFilters: map[string]string{
|
||||
"name": "john",
|
||||
},
|
||||
},
|
||||
shouldContain: []string{"WHERE", "name", "ILIKE"},
|
||||
},
|
||||
{
|
||||
name: "Apply search operators",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
params: &RequestParameters{
|
||||
SearchOps: map[string]FilterOperator{
|
||||
"age": {
|
||||
Operator: "gt",
|
||||
Value: "18",
|
||||
Logic: "AND",
|
||||
},
|
||||
},
|
||||
},
|
||||
shouldContain: []string{"WHERE", "age", ">", "18"},
|
||||
},
|
||||
{
|
||||
name: "Apply custom SQL WHERE",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
params: &RequestParameters{
|
||||
CustomSQLWhere: "deleted = false",
|
||||
},
|
||||
shouldContain: []string{"WHERE", "deleted"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
|
||||
|
||||
for _, expected := range tt.shouldContain {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestApplyDistinct tests DISTINCT application
|
||||
func TestApplyDistinct(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
distinct bool
|
||||
shouldHave string
|
||||
}{
|
||||
{
|
||||
name: "Apply DISTINCT",
|
||||
sqlQuery: "SELECT id, name FROM users",
|
||||
distinct: true,
|
||||
shouldHave: "SELECT DISTINCT",
|
||||
},
|
||||
{
|
||||
name: "Do not apply DISTINCT",
|
||||
sqlQuery: "SELECT id, name FROM users",
|
||||
distinct: false,
|
||||
shouldHave: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
params := &RequestParameters{Distinct: tt.distinct}
|
||||
result := handler.ApplyDistinct(tt.sqlQuery, params)
|
||||
|
||||
if tt.shouldHave != "" {
|
||||
if !strings.Contains(result, tt.shouldHave) {
|
||||
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
|
||||
}
|
||||
} else {
|
||||
// Should not have DISTINCT when not requested
|
||||
if strings.Contains(result, "DISTINCT") && !tt.distinct {
|
||||
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseCommaSeparated tests comma-separated value parsing
|
||||
func TestParseCommaSeparated(t *testing.T) {
|
||||
handler := NewHandler(&MockDatabase{})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Simple comma-separated",
|
||||
input: "id,name,email",
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "With spaces",
|
||||
input: "id, name, email",
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
input: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
input: "id",
|
||||
expected: []string{"id"},
|
||||
},
|
||||
{
|
||||
name: "With extra commas",
|
||||
input: "id,,name,,email",
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.parseCommaSeparated(tt.input)
|
||||
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
|
||||
return
|
||||
}
|
||||
|
||||
for i, expected := range tt.expected {
|
||||
if result[i] != expected {
|
||||
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSqlQryWhereOr tests OR WHERE clause manipulation
|
||||
func TestSqlQryWhereOr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sqlQuery string
|
||||
condition string
|
||||
shouldContain []string
|
||||
}{
|
||||
{
|
||||
name: "Add WHERE with OR to query without WHERE",
|
||||
sqlQuery: "SELECT * FROM users",
|
||||
condition: "status = 'inactive'",
|
||||
shouldContain: []string{"WHERE", "status = 'inactive'"},
|
||||
},
|
||||
{
|
||||
name: "Add OR to query with existing WHERE",
|
||||
sqlQuery: "SELECT * FROM users WHERE id > 0",
|
||||
condition: "status = 'inactive'",
|
||||
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
|
||||
|
||||
for _, expected := range tt.shouldContain {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
83
pkg/funcspec/security_adapter.go
Normal file
83
pkg/funcspec/security_adapter.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||
// We provide audit logging for data access tracking
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 2: BeforeQuery - Audit logging before single query execution
|
||||
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Note: Row-level security and column masking are challenging in funcspec
|
||||
// because the SQL query is fully user-defined. Security should be implemented
|
||||
// at the SQL function level or through database policies (RLS).
|
||||
}
|
||||
|
||||
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
|
||||
type funcSpecSecurityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &funcSpecSecurityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetContext() context.Context {
|
||||
return f.ctx.Context
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
|
||||
if f.ctx.UserContext == nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(f.ctx.UserContext.UserID), true
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetSchema() string {
|
||||
// funcspec doesn't have a schema concept, extract from SQL query or use default
|
||||
return "public"
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetEntity() string {
|
||||
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
|
||||
return "sql_query"
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetModel() interface{} {
|
||||
// funcspec doesn't use models in the same way as restheadspec
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetQuery() interface{} {
|
||||
// In funcspec, the query is a string, not a query builder object
|
||||
return f.ctx.SQLQuery
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
|
||||
// In funcspec, we could modify the SQL string, but this should be done cautiously
|
||||
if sqlQuery, ok := query.(string); ok {
|
||||
f.ctx.SQLQuery = sqlQuery
|
||||
}
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) GetResult() interface{} {
|
||||
return f.ctx.Result
|
||||
}
|
||||
|
||||
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
|
||||
f.ctx.Result = result
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -22,6 +23,15 @@ func Init(dev bool) {
|
||||
|
||||
}
|
||||
|
||||
func UpdateLoggerPath(path string, dev bool) {
|
||||
defaultConfig := zap.NewProductionConfig()
|
||||
if dev {
|
||||
defaultConfig = zap.NewDevelopmentConfig()
|
||||
}
|
||||
defaultConfig.OutputPaths = []string{path}
|
||||
UpdateLogger(&defaultConfig)
|
||||
}
|
||||
|
||||
func UpdateLogger(config *zap.Config) {
|
||||
defaultConfig := zap.NewProductionConfig()
|
||||
defaultConfig.OutputPaths = []string{"resolvespec.log"}
|
||||
@@ -70,3 +80,50 @@ func Debug(template string, args ...interface{}) {
|
||||
}
|
||||
Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanicCallback(location string, cb func(err any)) {
|
||||
if err := recover(); err != nil {
|
||||
// callstack := debug.Stack()
|
||||
|
||||
if Logger != nil {
|
||||
Error("Panic in %s : %v", location, err)
|
||||
} else {
|
||||
fmt.Printf("%s:PANIC->%+v", location, err)
|
||||
debug.PrintStack()
|
||||
}
|
||||
|
||||
// push to sentry
|
||||
// hub := sentry.CurrentHub()
|
||||
// if hub != nil {
|
||||
// evtID := hub.Recover(err)
|
||||
// if evtID != nil {
|
||||
// sentry.Flush(time.Second * 2)
|
||||
// }
|
||||
// }
|
||||
|
||||
if cb != nil {
|
||||
cb(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CatchPanic - Handle panic
|
||||
func CatchPanic(location string) {
|
||||
CatchPanicCallback(location, nil)
|
||||
}
|
||||
|
||||
// HandlePanic logs a panic and returns it as an error
|
||||
// This should be called with the result of recover() from a deferred function
|
||||
// Example usage:
|
||||
//
|
||||
// defer func() {
|
||||
// if r := recover(); r != nil {
|
||||
// err = logger.HandlePanic("MethodName", r)
|
||||
// }
|
||||
// }()
|
||||
func HandlePanic(methodName string, r any) error {
|
||||
stack := debug.Stack()
|
||||
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
|
||||
return fmt.Errorf("panic in %s: %v", methodName, r)
|
||||
}
|
||||
|
||||
259
pkg/metrics/README.md
Normal file
259
pkg/metrics/README.md
Normal file
@@ -0,0 +1,259 @@
|
||||
# Metrics Package
|
||||
|
||||
A pluggable metrics collection system with Prometheus implementation.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
|
||||
// Initialize Prometheus provider
|
||||
provider := metrics.NewPrometheusProvider()
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Apply middleware to your router
|
||||
router.Use(provider.Middleware)
|
||||
|
||||
// Expose metrics endpoint
|
||||
http.Handle("/metrics", provider.Handler())
|
||||
```
|
||||
|
||||
## Provider Interface
|
||||
|
||||
The package uses a provider interface, allowing you to plug in different metric systems:
|
||||
|
||||
```go
|
||||
type Provider interface {
|
||||
RecordHTTPRequest(method, path, status string, duration time.Duration)
|
||||
IncRequestsInFlight()
|
||||
DecRequestsInFlight()
|
||||
RecordDBQuery(operation, table string, duration time.Duration, err error)
|
||||
RecordCacheHit(provider string)
|
||||
RecordCacheMiss(provider string)
|
||||
UpdateCacheSize(provider string, size int64)
|
||||
Handler() http.Handler
|
||||
}
|
||||
```
|
||||
|
||||
## Recording Metrics
|
||||
|
||||
### HTTP Metrics (Automatic)
|
||||
|
||||
When using the middleware, HTTP metrics are recorded automatically:
|
||||
|
||||
```go
|
||||
router.Use(provider.Middleware)
|
||||
```
|
||||
|
||||
**Collected:**
|
||||
- Request duration (histogram)
|
||||
- Request count by method, path, and status
|
||||
- Requests in flight (gauge)
|
||||
|
||||
### Database Metrics
|
||||
|
||||
```go
|
||||
start := time.Now()
|
||||
rows, err := db.Query("SELECT * FROM users WHERE id = ?", userID)
|
||||
duration := time.Since(start)
|
||||
|
||||
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
|
||||
```
|
||||
|
||||
### Cache Metrics
|
||||
|
||||
```go
|
||||
// Record cache hit
|
||||
metrics.GetProvider().RecordCacheHit("memory")
|
||||
|
||||
// Record cache miss
|
||||
metrics.GetProvider().RecordCacheMiss("memory")
|
||||
|
||||
// Update cache size
|
||||
metrics.GetProvider().UpdateCacheSize("memory", 1024)
|
||||
```
|
||||
|
||||
## Prometheus Metrics
|
||||
|
||||
When using `PrometheusProvider`, the following metrics are available:
|
||||
|
||||
| Metric Name | Type | Labels | Description |
|
||||
|-------------|------|--------|-------------|
|
||||
| `http_request_duration_seconds` | Histogram | method, path, status | HTTP request duration |
|
||||
| `http_requests_total` | Counter | method, path, status | Total HTTP requests |
|
||||
| `http_requests_in_flight` | Gauge | - | Current in-flight requests |
|
||||
| `db_query_duration_seconds` | Histogram | operation, table | Database query duration |
|
||||
| `db_queries_total` | Counter | operation, table, status | Total database queries |
|
||||
| `cache_hits_total` | Counter | provider | Total cache hits |
|
||||
| `cache_misses_total` | Counter | provider | Total cache misses |
|
||||
| `cache_size_items` | Gauge | provider | Current cache size |
|
||||
|
||||
## Prometheus Queries
|
||||
|
||||
### HTTP Request Rate
|
||||
|
||||
```promql
|
||||
rate(http_requests_total[5m])
|
||||
```
|
||||
|
||||
### HTTP Request Duration (95th percentile)
|
||||
|
||||
```promql
|
||||
histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))
|
||||
```
|
||||
|
||||
### Database Query Error Rate
|
||||
|
||||
```promql
|
||||
rate(db_queries_total{status="error"}[5m])
|
||||
```
|
||||
|
||||
### Cache Hit Rate
|
||||
|
||||
```promql
|
||||
rate(cache_hits_total[5m]) / (rate(cache_hits_total[5m]) + rate(cache_misses_total[5m]))
|
||||
```
|
||||
|
||||
## No-Op Provider
|
||||
|
||||
If metrics are disabled:
|
||||
|
||||
```go
|
||||
// No provider set - uses no-op provider automatically
|
||||
metrics.GetProvider().RecordHTTPRequest(...) // Does nothing
|
||||
```
|
||||
|
||||
## Custom Provider
|
||||
|
||||
Implement your own metrics provider:
|
||||
|
||||
```go
|
||||
type CustomProvider struct{}
|
||||
|
||||
func (c *CustomProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
|
||||
// Send to your metrics system
|
||||
}
|
||||
|
||||
// Implement other Provider interface methods...
|
||||
|
||||
func (c *CustomProvider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Return your metrics format
|
||||
})
|
||||
}
|
||||
|
||||
// Use it
|
||||
metrics.SetProvider(&CustomProvider{})
|
||||
```
|
||||
|
||||
## Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Initialize metrics
|
||||
provider := metrics.NewPrometheusProvider()
|
||||
metrics.SetProvider(provider)
|
||||
|
||||
// Create router
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Apply metrics middleware
|
||||
router.Use(provider.Middleware)
|
||||
|
||||
// Expose metrics endpoint
|
||||
router.Handle("/metrics", provider.Handler())
|
||||
|
||||
// Your API routes
|
||||
router.HandleFunc("/api/users", getUsersHandler)
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
|
||||
func getUsersHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Record database query
|
||||
start := time.Now()
|
||||
users, err := fetchUsers()
|
||||
duration := time.Since(start)
|
||||
|
||||
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
|
||||
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", 500)
|
||||
return
|
||||
}
|
||||
|
||||
// Return users...
|
||||
}
|
||||
```
|
||||
|
||||
## Docker Compose Example
|
||||
|
||||
```yaml
|
||||
version: '3'
|
||||
services:
|
||||
app:
|
||||
build: .
|
||||
ports:
|
||||
- "8080:8080"
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
depends_on:
|
||||
- prometheus
|
||||
```
|
||||
|
||||
**prometheus.yml:**
|
||||
|
||||
```yaml
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'resolvespec'
|
||||
static_configs:
|
||||
- targets: ['app:8080']
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Label Cardinality**: Keep labels low-cardinality
|
||||
- ✅ Good: `method`, `status_code`
|
||||
- ❌ Bad: `user_id`, `timestamp`
|
||||
|
||||
2. **Path Normalization**: Normalize dynamic paths
|
||||
```go
|
||||
// Instead of /api/users/123
|
||||
// Use /api/users/:id
|
||||
```
|
||||
|
||||
3. **Metric Naming**: Follow Prometheus conventions
|
||||
- Use `_total` suffix for counters
|
||||
- Use `_seconds` suffix for durations
|
||||
- Use base units (seconds, not milliseconds)
|
||||
|
||||
4. **Performance**: Metrics collection is lock-free and highly performant
|
||||
- Safe for high-throughput applications
|
||||
- Minimal overhead (<1% in most cases)
|
||||
73
pkg/metrics/interfaces.go
Normal file
73
pkg/metrics/interfaces.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// Provider defines the interface for metric collection
|
||||
type Provider interface {
|
||||
// RecordHTTPRequest records metrics for an HTTP request
|
||||
RecordHTTPRequest(method, path, status string, duration time.Duration)
|
||||
|
||||
// IncRequestsInFlight increments the in-flight requests counter
|
||||
IncRequestsInFlight()
|
||||
|
||||
// DecRequestsInFlight decrements the in-flight requests counter
|
||||
DecRequestsInFlight()
|
||||
|
||||
// RecordDBQuery records metrics for a database query
|
||||
RecordDBQuery(operation, table string, duration time.Duration, err error)
|
||||
|
||||
// RecordCacheHit records a cache hit
|
||||
RecordCacheHit(provider string)
|
||||
|
||||
// RecordCacheMiss records a cache miss
|
||||
RecordCacheMiss(provider string)
|
||||
|
||||
// UpdateCacheSize updates the cache size metric
|
||||
UpdateCacheSize(provider string, size int64)
|
||||
|
||||
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
|
||||
Handler() http.Handler
|
||||
}
|
||||
|
||||
// globalProvider is the global metrics provider
|
||||
var globalProvider Provider
|
||||
|
||||
// SetProvider sets the global metrics provider
|
||||
func SetProvider(p Provider) {
|
||||
globalProvider = p
|
||||
}
|
||||
|
||||
// GetProvider returns the current metrics provider
|
||||
func GetProvider() Provider {
|
||||
if globalProvider == nil {
|
||||
// Return no-op provider if none is set
|
||||
return &NoOpProvider{}
|
||||
}
|
||||
return globalProvider
|
||||
}
|
||||
|
||||
// NoOpProvider is a no-op implementation of Provider
|
||||
type NoOpProvider struct{}
|
||||
|
||||
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
|
||||
func (n *NoOpProvider) IncRequestsInFlight() {}
|
||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
}
|
||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
|
||||
func (n *NoOpProvider) Handler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, err := w.Write([]byte("Metrics provider not configured"))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to write. %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
174
pkg/metrics/prometheus.go
Normal file
174
pkg/metrics/prometheus.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
// PrometheusProvider implements the Provider interface using Prometheus
|
||||
type PrometheusProvider struct {
|
||||
requestDuration *prometheus.HistogramVec
|
||||
requestTotal *prometheus.CounterVec
|
||||
requestsInFlight prometheus.Gauge
|
||||
dbQueryDuration *prometheus.HistogramVec
|
||||
dbQueryTotal *prometheus.CounterVec
|
||||
cacheHits *prometheus.CounterVec
|
||||
cacheMisses *prometheus.CounterVec
|
||||
cacheSize *prometheus.GaugeVec
|
||||
}
|
||||
|
||||
// NewPrometheusProvider creates a new Prometheus metrics provider
|
||||
func NewPrometheusProvider() *PrometheusProvider {
|
||||
return &PrometheusProvider{
|
||||
requestDuration: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "http_request_duration_seconds",
|
||||
Help: "HTTP request duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
),
|
||||
requestTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "http_requests_total",
|
||||
Help: "Total number of HTTP requests",
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
),
|
||||
|
||||
requestsInFlight: promauto.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "http_requests_in_flight",
|
||||
Help: "Current number of HTTP requests being processed",
|
||||
},
|
||||
),
|
||||
dbQueryDuration: promauto.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "db_query_duration_seconds",
|
||||
Help: "Database query duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"operation", "table"},
|
||||
),
|
||||
dbQueryTotal: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "db_queries_total",
|
||||
Help: "Total number of database queries",
|
||||
},
|
||||
[]string{"operation", "table", "status"},
|
||||
),
|
||||
cacheHits: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_hits_total",
|
||||
Help: "Total number of cache hits",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
cacheMisses: promauto.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_misses_total",
|
||||
Help: "Total number of cache misses",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
cacheSize: promauto.NewGaugeVec(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "cache_size_items",
|
||||
Help: "Number of items in cache",
|
||||
},
|
||||
[]string{"provider"},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseWriter wraps http.ResponseWriter to capture status code
|
||||
type ResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
|
||||
return &ResponseWriter{
|
||||
ResponseWriter: w,
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *ResponseWriter) WriteHeader(code int) {
|
||||
rw.statusCode = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// RecordHTTPRequest implements Provider interface
|
||||
func (p *PrometheusProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
|
||||
p.requestDuration.WithLabelValues(method, path, status).Observe(duration.Seconds())
|
||||
p.requestTotal.WithLabelValues(method, path, status).Inc()
|
||||
}
|
||||
|
||||
// IncRequestsInFlight implements Provider interface
|
||||
func (p *PrometheusProvider) IncRequestsInFlight() {
|
||||
p.requestsInFlight.Inc()
|
||||
}
|
||||
|
||||
// DecRequestsInFlight implements Provider interface
|
||||
func (p *PrometheusProvider) DecRequestsInFlight() {
|
||||
p.requestsInFlight.Dec()
|
||||
}
|
||||
|
||||
// RecordDBQuery implements Provider interface
|
||||
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
||||
status := "success"
|
||||
if err != nil {
|
||||
status = "error"
|
||||
}
|
||||
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
|
||||
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheHit implements Provider interface
|
||||
func (p *PrometheusProvider) RecordCacheHit(provider string) {
|
||||
p.cacheHits.WithLabelValues(provider).Inc()
|
||||
}
|
||||
|
||||
// RecordCacheMiss implements Provider interface
|
||||
func (p *PrometheusProvider) RecordCacheMiss(provider string) {
|
||||
p.cacheMisses.WithLabelValues(provider).Inc()
|
||||
}
|
||||
|
||||
// UpdateCacheSize implements Provider interface
|
||||
func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
|
||||
p.cacheSize.WithLabelValues(provider).Set(float64(size))
|
||||
}
|
||||
|
||||
// Handler implements Provider interface
|
||||
func (p *PrometheusProvider) Handler() http.Handler {
|
||||
return promhttp.Handler()
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that collects metrics
|
||||
func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
// Increment in-flight requests
|
||||
p.IncRequestsInFlight()
|
||||
defer p.DecRequestsInFlight()
|
||||
|
||||
// Wrap response writer to capture status code
|
||||
rw := NewResponseWriter(w)
|
||||
|
||||
// Call next handler
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(start)
|
||||
status := strconv.Itoa(rw.statusCode)
|
||||
|
||||
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
|
||||
})
|
||||
}
|
||||
806
pkg/middleware/README.md
Normal file
806
pkg/middleware/README.md
Normal file
@@ -0,0 +1,806 @@
|
||||
# Middleware Package
|
||||
|
||||
HTTP middleware utilities for security and performance.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Rate Limiting](#rate-limiting)
|
||||
2. [Request Size Limits](#request-size-limits)
|
||||
3. [Input Sanitization](#input-sanitization)
|
||||
|
||||
---
|
||||
|
||||
## Rate Limiting
|
||||
|
||||
Production-grade rate limiting using token bucket algorithm.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
// Create rate limiter: 100 requests per second, burst of 20
|
||||
rateLimiter := middleware.NewRateLimiter(100, 20)
|
||||
|
||||
// Apply to all routes
|
||||
router.Use(rateLimiter.Middleware)
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Rate limit: 10 requests per second, burst of 5
|
||||
rateLimiter := middleware.NewRateLimiter(10, 5)
|
||||
router.Use(rateLimiter.Middleware)
|
||||
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
### Custom Key Extraction
|
||||
|
||||
By default, rate limiting is per IP address. Customize the key:
|
||||
|
||||
```go
|
||||
// Rate limit by User ID from header
|
||||
keyFunc := func(r *http.Request) string {
|
||||
userID := r.Header.Get("X-User-ID")
|
||||
if userID == "" {
|
||||
return r.RemoteAddr // Fallback to IP
|
||||
}
|
||||
return "user:" + userID
|
||||
}
|
||||
|
||||
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
|
||||
```
|
||||
|
||||
### Advanced Key Functions
|
||||
|
||||
**By API Key:**
|
||||
|
||||
```go
|
||||
keyFunc := func(r *http.Request) string {
|
||||
apiKey := r.Header.Get("X-API-Key")
|
||||
if apiKey == "" {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return "api:" + apiKey
|
||||
}
|
||||
```
|
||||
|
||||
**By Authenticated User:**
|
||||
|
||||
```go
|
||||
keyFunc := func(r *http.Request) string {
|
||||
// Extract from JWT or session
|
||||
user := getUserFromContext(r.Context())
|
||||
if user != nil {
|
||||
return "user:" + user.ID
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
```
|
||||
|
||||
**By Path + User:**
|
||||
|
||||
```go
|
||||
keyFunc := func(r *http.Request) string {
|
||||
user := getUserFromContext(r.Context())
|
||||
if user != nil {
|
||||
return fmt.Sprintf("user:%s:path:%s", user.ID, r.URL.Path)
|
||||
}
|
||||
return r.URL.Path + ":" + r.RemoteAddr
|
||||
}
|
||||
```
|
||||
|
||||
### Different Limits Per Route
|
||||
|
||||
```go
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Public endpoints: 10 rps
|
||||
publicLimiter := middleware.NewRateLimiter(10, 5)
|
||||
|
||||
// API endpoints: 100 rps
|
||||
apiLimiter := middleware.NewRateLimiter(100, 20)
|
||||
|
||||
// Admin endpoints: 1000 rps
|
||||
adminLimiter := middleware.NewRateLimiter(1000, 50)
|
||||
|
||||
// Apply different limiters to subrouters
|
||||
publicRouter := router.PathPrefix("/public").Subrouter()
|
||||
publicRouter.Use(publicLimiter.Middleware)
|
||||
|
||||
apiRouter := router.PathPrefix("/api").Subrouter()
|
||||
apiRouter.Use(apiLimiter.Middleware)
|
||||
|
||||
adminRouter := router.PathPrefix("/admin").Subrouter()
|
||||
adminRouter.Use(adminLimiter.Middleware)
|
||||
}
|
||||
```
|
||||
|
||||
### Rate Limit Response
|
||||
|
||||
When rate limited, clients receive:
|
||||
|
||||
```http
|
||||
HTTP/1.1 429 Too Many Requests
|
||||
Content-Type: text/plain
|
||||
|
||||
```
|
||||
|
||||
### Configuration Examples
|
||||
|
||||
**Tight Rate Limit (Anti-abuse):**
|
||||
|
||||
```go
|
||||
// 1 request per second, burst of 3
|
||||
rateLimiter := middleware.NewRateLimiter(1, 3)
|
||||
```
|
||||
|
||||
**Moderate Rate Limit (Standard API):**
|
||||
|
||||
```go
|
||||
// 100 requests per second, burst of 20
|
||||
rateLimiter := middleware.NewRateLimiter(100, 20)
|
||||
```
|
||||
|
||||
**Generous Rate Limit (Internal Services):**
|
||||
|
||||
```go
|
||||
// 1000 requests per second, burst of 100
|
||||
rateLimiter := middleware.NewRateLimiter(1000, 100)
|
||||
```
|
||||
|
||||
**Time-based Limits:**
|
||||
|
||||
```go
|
||||
// 60 requests per minute = 1 request per second
|
||||
rateLimiter := middleware.NewRateLimiter(1, 10)
|
||||
|
||||
// 1000 requests per hour ≈ 0.28 requests per second
|
||||
rateLimiter := middleware.NewRateLimiter(0.28, 50)
|
||||
```
|
||||
|
||||
### Understanding Burst
|
||||
|
||||
The burst parameter allows short bursts above the rate:
|
||||
|
||||
```go
|
||||
// Rate: 10 rps, Burst: 5
|
||||
// Allows up to 5 requests immediately, then 10/second
|
||||
rateLimiter := middleware.NewRateLimiter(10, 5)
|
||||
```
|
||||
|
||||
**Bucket fills at rate:** 10 tokens/second
|
||||
**Bucket capacity:** 5 tokens
|
||||
**Request consumes:** 1 token
|
||||
|
||||
**Example traffic pattern:**
|
||||
- T=0s: 5 requests → ✅ All allowed (burst)
|
||||
- T=0.1s: 1 request → ❌ Denied (bucket empty)
|
||||
- T=0.5s: 1 request → ✅ Allowed (bucket refilled 0.5 tokens)
|
||||
- T=1s: 1 request → ✅ Allowed (bucket has ~1 token)
|
||||
|
||||
### Cleanup Behavior
|
||||
|
||||
The rate limiter automatically cleans up inactive limiters every 5 minutes to prevent memory leaks.
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
- **Memory**: ~100 bytes per active limiter
|
||||
- **Throughput**: >1M requests/second
|
||||
- **Latency**: <1μs per request
|
||||
- **Concurrency**: Lock-free for rate checks
|
||||
|
||||
### Production Deployment
|
||||
|
||||
**With Reverse Proxy:**
|
||||
|
||||
```go
|
||||
// Use X-Forwarded-For or X-Real-IP
|
||||
keyFunc := func(r *http.Request) string {
|
||||
// Check proxy headers first
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
return strings.Split(ip, ",")[0]
|
||||
}
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
|
||||
```
|
||||
|
||||
**Environment-based Configuration:**
|
||||
|
||||
```go
|
||||
import "os"
|
||||
|
||||
func getRateLimiter() *middleware.RateLimiter {
|
||||
rps := getEnvFloat("RATE_LIMIT_RPS", 100)
|
||||
burst := getEnvInt("RATE_LIMIT_BURST", 20)
|
||||
return middleware.NewRateLimiter(rps, burst)
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Rate Limits
|
||||
|
||||
```bash
|
||||
# Send 10 requests rapidly
|
||||
for i in {1..10}; do
|
||||
curl -w "Status: %{http_code}\n" http://localhost:8080/api/data
|
||||
done
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
Status: 200 # Request 1-5 (within burst)
|
||||
Status: 200
|
||||
Status: 200
|
||||
Status: 200
|
||||
Status: 200
|
||||
Status: 429 # Request 6-10 (rate limited)
|
||||
Status: 429
|
||||
Status: 429
|
||||
Status: 429
|
||||
Status: 429
|
||||
```
|
||||
|
||||
### Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// Configuration from environment
|
||||
rps, _ := strconv.ParseFloat(os.Getenv("RATE_LIMIT_RPS"), 64)
|
||||
if rps == 0 {
|
||||
rps = 100 // Default
|
||||
}
|
||||
|
||||
burst, _ := strconv.Atoi(os.Getenv("RATE_LIMIT_BURST"))
|
||||
if burst == 0 {
|
||||
burst = 20 // Default
|
||||
}
|
||||
|
||||
// Create rate limiter
|
||||
rateLimiter := middleware.NewRateLimiter(rps, burst)
|
||||
|
||||
// Custom key extraction
|
||||
keyFunc := func(r *http.Request) string {
|
||||
// Try API key first
|
||||
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
|
||||
return "api:" + apiKey
|
||||
}
|
||||
// Try authenticated user
|
||||
if userID := r.Header.Get("X-User-ID"); userID != "" {
|
||||
return "user:" + userID
|
||||
}
|
||||
// Fall back to IP
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
return ip
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
// Create router
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Apply rate limiting
|
||||
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
|
||||
|
||||
// Routes
|
||||
router.HandleFunc("/api/data", dataHandler)
|
||||
router.HandleFunc("/health", healthHandler)
|
||||
|
||||
log.Printf("Starting server with rate limit: %.1f rps, burst: %d", rps, burst)
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
|
||||
func dataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"message": "Data endpoint",
|
||||
})
|
||||
}
|
||||
|
||||
func healthHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Set Appropriate Limits**: Consider your backend capacity
|
||||
- Database: Can it handle X queries/second?
|
||||
- External APIs: What are their rate limits?
|
||||
- Server resources: CPU, memory, connections
|
||||
|
||||
2. **Use Burst Wisely**: Allow legitimate traffic spikes
|
||||
- Too low: Reject valid bursts
|
||||
- Too high: Allow abuse
|
||||
|
||||
3. **Monitor Rate Limits**: Track how often limits are hit
|
||||
```go
|
||||
// Log rate limit events
|
||||
if rateLimited {
|
||||
log.Printf("Rate limited: %s", clientKey)
|
||||
}
|
||||
```
|
||||
|
||||
4. **Provide Feedback**: Include rate limit headers (future enhancement)
|
||||
```http
|
||||
X-RateLimit-Limit: 100
|
||||
X-RateLimit-Remaining: 95
|
||||
X-RateLimit-Reset: 1640000000
|
||||
```
|
||||
|
||||
5. **Tiered Limits**: Different limits for different user tiers
|
||||
```go
|
||||
func getRateLimiter(userTier string) *middleware.RateLimiter {
|
||||
switch userTier {
|
||||
case "premium":
|
||||
return middleware.NewRateLimiter(1000, 100)
|
||||
case "standard":
|
||||
return middleware.NewRateLimiter(100, 20)
|
||||
default:
|
||||
return middleware.NewRateLimiter(10, 5)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Request Size Limits
|
||||
|
||||
Protect against oversized request bodies with configurable size limits.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
// Default: 10MB limit
|
||||
sizeLimiter := middleware.NewRequestSizeLimiter(0)
|
||||
router.Use(sizeLimiter.Middleware)
|
||||
```
|
||||
|
||||
### Custom Size Limit
|
||||
|
||||
```go
|
||||
// 5MB limit
|
||||
sizeLimiter := middleware.NewRequestSizeLimiter(5 * 1024 * 1024)
|
||||
router.Use(sizeLimiter.Middleware)
|
||||
|
||||
// Or use constants
|
||||
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size5MB)
|
||||
```
|
||||
|
||||
### Available Size Constants
|
||||
|
||||
```go
|
||||
middleware.Size1MB // 1 MB
|
||||
middleware.Size5MB // 5 MB
|
||||
middleware.Size10MB // 10 MB (default)
|
||||
middleware.Size50MB // 50 MB
|
||||
middleware.Size100MB // 100 MB
|
||||
```
|
||||
|
||||
### Different Limits Per Route
|
||||
|
||||
```go
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// File upload endpoint: 50MB
|
||||
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
|
||||
uploadRouter := router.PathPrefix("/upload").Subrouter()
|
||||
uploadRouter.Use(uploadLimiter.Middleware)
|
||||
|
||||
// API endpoints: 1MB
|
||||
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
|
||||
apiRouter := router.PathPrefix("/api").Subrouter()
|
||||
apiRouter.Use(apiLimiter.Middleware)
|
||||
}
|
||||
```
|
||||
|
||||
### Dynamic Size Limits
|
||||
|
||||
```go
|
||||
// Custom size based on request
|
||||
sizeFunc := func(r *http.Request) int64 {
|
||||
// Premium users get 50MB
|
||||
if isPremiumUser(r) {
|
||||
return middleware.Size50MB
|
||||
}
|
||||
// Free users get 5MB
|
||||
return middleware.Size5MB
|
||||
}
|
||||
|
||||
router.Use(sizeLimiter.MiddlewareWithCustomSize(sizeFunc))
|
||||
```
|
||||
|
||||
**By Content-Type:**
|
||||
|
||||
```go
|
||||
sizeFunc := func(r *http.Request) int64 {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
switch {
|
||||
case strings.Contains(contentType, "multipart/form-data"):
|
||||
return middleware.Size50MB // File uploads
|
||||
case strings.Contains(contentType, "application/json"):
|
||||
return middleware.Size1MB // JSON APIs
|
||||
default:
|
||||
return middleware.Size10MB // Default
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Error Response
|
||||
|
||||
When size limit exceeded:
|
||||
|
||||
```http
|
||||
HTTP/1.1 413 Request Entity Too Large
|
||||
X-Max-Request-Size: 10485760
|
||||
|
||||
http: request body too large
|
||||
```
|
||||
|
||||
### Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// API routes: 1MB limit
|
||||
api := router.PathPrefix("/api").Subrouter()
|
||||
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
|
||||
api.Use(apiLimiter.Middleware)
|
||||
api.HandleFunc("/users", createUserHandler).Methods("POST")
|
||||
|
||||
// Upload routes: 50MB limit
|
||||
upload := router.PathPrefix("/upload").Subrouter()
|
||||
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
|
||||
upload.Use(uploadLimiter.Middleware)
|
||||
upload.HandleFunc("/file", uploadFileHandler).Methods("POST")
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Input Sanitization
|
||||
|
||||
Protect against XSS, injection attacks, and malicious input.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
// Default sanitizer (safe defaults)
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
router.Use(sanitizer.Middleware)
|
||||
```
|
||||
|
||||
### Sanitizer Types
|
||||
|
||||
**Default Sanitizer (Recommended):**
|
||||
|
||||
```go
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
// ✓ Escapes HTML entities
|
||||
// ✓ Removes null bytes
|
||||
// ✓ Removes control characters
|
||||
// ✓ Blocks XSS patterns (script tags, event handlers)
|
||||
// ✗ Does not strip HTML (allows legitimate content)
|
||||
```
|
||||
|
||||
**Strict Sanitizer:**
|
||||
|
||||
```go
|
||||
sanitizer := middleware.StrictSanitizer()
|
||||
// ✓ All default features
|
||||
// ✓ Strips ALL HTML tags
|
||||
// ✓ Max string length: 10,000 chars
|
||||
```
|
||||
|
||||
### Custom Configuration
|
||||
|
||||
```go
|
||||
sanitizer := &middleware.Sanitizer{
|
||||
StripHTML: true, // Remove HTML tags
|
||||
EscapeHTML: false, // Don't escape (already stripped)
|
||||
RemoveNullBytes: true, // Remove \x00
|
||||
RemoveControlChars: true, // Remove dangerous control chars
|
||||
MaxStringLength: 5000, // Limit to 5000 chars
|
||||
|
||||
// Block patterns (regex)
|
||||
BlockPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script`),
|
||||
regexp.MustCompile(`(?i)javascript:`),
|
||||
},
|
||||
|
||||
// Custom sanitization function
|
||||
CustomSanitizer: func(s string) string {
|
||||
// Your custom logic
|
||||
return strings.ToLower(s)
|
||||
},
|
||||
}
|
||||
|
||||
router.Use(sanitizer.Middleware)
|
||||
```
|
||||
|
||||
### What Gets Sanitized
|
||||
|
||||
**Automatic (via middleware):**
|
||||
- Query parameters
|
||||
- Headers (User-Agent, Referer, X-Forwarded-For, X-Real-IP)
|
||||
|
||||
**Manual (in your handler):**
|
||||
- Request body (JSON, form data)
|
||||
- Database queries
|
||||
- File names
|
||||
|
||||
### Manual Sanitization
|
||||
|
||||
**String Values:**
|
||||
|
||||
```go
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
|
||||
// Sanitize user input
|
||||
username := sanitizer.Sanitize(r.FormValue("username"))
|
||||
email := sanitizer.Sanitize(r.FormValue("email"))
|
||||
```
|
||||
|
||||
**Map/JSON Data:**
|
||||
|
||||
```go
|
||||
var data map[string]interface{}
|
||||
json.Unmarshal(body, &data)
|
||||
|
||||
// Sanitize all string values recursively
|
||||
sanitizedData := sanitizer.SanitizeMap(data)
|
||||
```
|
||||
|
||||
**Nested Structures:**
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
Name string
|
||||
Email string
|
||||
Bio string
|
||||
Profile map[string]interface{}
|
||||
}
|
||||
|
||||
// After unmarshaling
|
||||
user.Name = sanitizer.Sanitize(user.Name)
|
||||
user.Email = sanitizer.Sanitize(user.Email)
|
||||
user.Bio = sanitizer.Sanitize(user.Bio)
|
||||
user.Profile = sanitizer.SanitizeMap(user.Profile)
|
||||
```
|
||||
|
||||
### Specialized Sanitizers
|
||||
|
||||
**Filenames:**
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
|
||||
filename := middleware.SanitizeFilename(uploadedFilename)
|
||||
// Removes: .., /, \, null bytes
|
||||
// Limits: 255 characters
|
||||
```
|
||||
|
||||
**Emails:**
|
||||
|
||||
```go
|
||||
email := middleware.SanitizeEmail(" USER@EXAMPLE.COM ")
|
||||
// Result: "user@example.com"
|
||||
// Trims, lowercases, removes null bytes
|
||||
```
|
||||
|
||||
**URLs:**
|
||||
|
||||
```go
|
||||
url := middleware.SanitizeURL(userInput)
|
||||
// Blocks: javascript:, data: protocols
|
||||
// Removes: null bytes
|
||||
```
|
||||
|
||||
### Blocked Patterns (Default)
|
||||
|
||||
The default sanitizer blocks:
|
||||
|
||||
1. **Script tags**: `<script>...</script>`
|
||||
2. **JavaScript protocol**: `javascript:alert(1)`
|
||||
3. **Event handlers**: `onclick="..."`, `onerror="..."`
|
||||
4. **Iframes**: `<iframe src="...">`
|
||||
5. **Objects**: `<object data="...">`
|
||||
6. **Embeds**: `<embed src="...">`
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
**1. Layer Defense:**
|
||||
|
||||
```go
|
||||
// Layer 1: Middleware (query params, headers)
|
||||
router.Use(sanitizer.Middleware)
|
||||
|
||||
// Layer 2: Input validation (in handler)
|
||||
func createUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var user User
|
||||
json.NewDecoder(r.Body).Decode(&user)
|
||||
|
||||
// Sanitize
|
||||
user.Name = sanitizer.Sanitize(user.Name)
|
||||
user.Email = middleware.SanitizeEmail(user.Email)
|
||||
|
||||
// Validate
|
||||
if !isValidEmail(user.Email) {
|
||||
http.Error(w, "Invalid email", 400)
|
||||
return
|
||||
}
|
||||
|
||||
// Use parameterized queries (prevents SQL injection)
|
||||
db.Exec("INSERT INTO users (name, email) VALUES (?, ?)",
|
||||
user.Name, user.Email)
|
||||
}
|
||||
```
|
||||
|
||||
**2. Context-Aware Sanitization:**
|
||||
|
||||
```go
|
||||
// HTML content (user posts, comments)
|
||||
sanitizer := middleware.StrictSanitizer()
|
||||
post.Content = sanitizer.Sanitize(post.Content)
|
||||
|
||||
// Structured data (JSON API)
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
data = sanitizer.SanitizeMap(jsonData)
|
||||
|
||||
// Search queries (preserve special chars)
|
||||
query = middleware.SanitizeFilename(searchTerm) // Light sanitization
|
||||
```
|
||||
|
||||
**3. Output Encoding:**
|
||||
|
||||
```go
|
||||
// When rendering HTML
|
||||
import "html/template"
|
||||
|
||||
tmpl := template.Must(template.New("page").Parse(`
|
||||
<h1>{{.Title}}</h1>
|
||||
<p>{{.Content}}</p>
|
||||
`))
|
||||
|
||||
// template.HTML automatically escapes
|
||||
tmpl.Execute(w, data)
|
||||
```
|
||||
|
||||
### Complete Example
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/middleware"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Apply sanitization middleware
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
router.Use(sanitizer.Middleware)
|
||||
|
||||
router.HandleFunc("/api/users", createUserHandler).Methods("POST")
|
||||
|
||||
log.Fatal(http.ListenAndServe(":8080", router))
|
||||
}
|
||||
|
||||
func createUserHandler(w http.ResponseWriter, r *http.Request) {
|
||||
sanitizer := middleware.DefaultSanitizer()
|
||||
|
||||
var user struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Bio string `json:"bio"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
|
||||
http.Error(w, "Invalid JSON", 400)
|
||||
return
|
||||
}
|
||||
|
||||
// Sanitize inputs
|
||||
user.Name = sanitizer.Sanitize(user.Name)
|
||||
user.Email = middleware.SanitizeEmail(user.Email)
|
||||
user.Bio = sanitizer.Sanitize(user.Bio)
|
||||
|
||||
// Validate
|
||||
if len(user.Name) == 0 || len(user.Email) == 0 {
|
||||
http.Error(w, "Name and email required", 400)
|
||||
return
|
||||
}
|
||||
|
||||
// Save to database (use parameterized queries!)
|
||||
// db.Exec("INSERT INTO users (name, email, bio) VALUES (?, ?, ?)",
|
||||
// user.Name, user.Email, user.Bio)
|
||||
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]string{
|
||||
"status": "created",
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Sanitization
|
||||
|
||||
```bash
|
||||
# Test XSS prevention
|
||||
curl -X POST http://localhost:8080/api/users \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"name": "<script>alert(1)</script>John",
|
||||
"email": "test@example.com",
|
||||
"bio": "My bio with <iframe src=\"evil.com\"></iframe>"
|
||||
}'
|
||||
|
||||
# Script tags and iframes should be removed
|
||||
```
|
||||
|
||||
### Performance
|
||||
|
||||
- **Overhead**: <1ms per request for typical payloads
|
||||
- **Regex compilation**: Done once at initialization
|
||||
- **Safe for production**: Minimal performance impact
|
||||
- **Safe for production**: Minimal performance impact
|
||||
212
pkg/middleware/blacklist.go
Normal file
212
pkg/middleware/blacklist.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// IPBlacklist provides IP blocking functionality
|
||||
type IPBlacklist struct {
|
||||
mu sync.RWMutex
|
||||
ips map[string]bool // Individual IPs
|
||||
cidrs []*net.IPNet // CIDR ranges
|
||||
reason map[string]string
|
||||
useProxy bool // Whether to check X-Forwarded-For headers
|
||||
}
|
||||
|
||||
// BlacklistConfig configures the IP blacklist
|
||||
type BlacklistConfig struct {
|
||||
// UseProxy indicates whether to extract IP from X-Forwarded-For/X-Real-IP headers
|
||||
UseProxy bool
|
||||
}
|
||||
|
||||
// NewIPBlacklist creates a new IP blacklist
|
||||
func NewIPBlacklist(config BlacklistConfig) *IPBlacklist {
|
||||
return &IPBlacklist{
|
||||
ips: make(map[string]bool),
|
||||
cidrs: make([]*net.IPNet, 0),
|
||||
reason: make(map[string]string),
|
||||
useProxy: config.UseProxy,
|
||||
}
|
||||
}
|
||||
|
||||
// BlockIP blocks a single IP address
|
||||
func (bl *IPBlacklist) BlockIP(ip string, reason string) error {
|
||||
// Validate IP
|
||||
if net.ParseIP(ip) == nil {
|
||||
return &net.ParseError{Type: "IP address", Text: ip}
|
||||
}
|
||||
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
bl.ips[ip] = true
|
||||
if reason != "" {
|
||||
bl.reason[ip] = reason
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BlockCIDR blocks an IP range using CIDR notation
|
||||
func (bl *IPBlacklist) BlockCIDR(cidr string, reason string) error {
|
||||
_, ipNet, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
bl.cidrs = append(bl.cidrs, ipNet)
|
||||
if reason != "" {
|
||||
bl.reason[cidr] = reason
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnblockIP removes an IP from the blacklist
|
||||
func (bl *IPBlacklist) UnblockIP(ip string) {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
delete(bl.ips, ip)
|
||||
delete(bl.reason, ip)
|
||||
}
|
||||
|
||||
// UnblockCIDR removes a CIDR range from the blacklist
|
||||
func (bl *IPBlacklist) UnblockCIDR(cidr string) {
|
||||
bl.mu.Lock()
|
||||
defer bl.mu.Unlock()
|
||||
|
||||
// Find and remove the CIDR
|
||||
for i, ipNet := range bl.cidrs {
|
||||
if ipNet.String() == cidr {
|
||||
bl.cidrs = append(bl.cidrs[:i], bl.cidrs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
delete(bl.reason, cidr)
|
||||
}
|
||||
|
||||
// IsBlocked checks if an IP is blacklisted
|
||||
func (bl *IPBlacklist) IsBlocked(ip string) (blacklist bool, reason string) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
// Check individual IPs
|
||||
if bl.ips[ip] {
|
||||
return true, bl.reason[ip]
|
||||
}
|
||||
|
||||
// Check CIDR ranges
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
for i, ipNet := range bl.cidrs {
|
||||
if ipNet.Contains(parsedIP) {
|
||||
cidr := ipNet.String()
|
||||
// Try to find reason by CIDR or by index
|
||||
if reason, ok := bl.reason[cidr]; ok {
|
||||
return true, reason
|
||||
}
|
||||
// Check if reason was stored by original CIDR string
|
||||
for key, reason := range bl.reason {
|
||||
if strings.Contains(key, "/") && key == cidr {
|
||||
return true, reason
|
||||
}
|
||||
}
|
||||
// Return true even if no reason found
|
||||
if i < len(bl.cidrs) {
|
||||
return true, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// GetBlacklist returns all blacklisted IPs and CIDRs
|
||||
func (bl *IPBlacklist) GetBlacklist() (ips []string, cidrs []string) {
|
||||
bl.mu.RLock()
|
||||
defer bl.mu.RUnlock()
|
||||
|
||||
ips = make([]string, 0, len(bl.ips))
|
||||
for ip := range bl.ips {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
|
||||
cidrs = make([]string, 0, len(bl.cidrs))
|
||||
for _, ipNet := range bl.cidrs {
|
||||
cidrs = append(cidrs, ipNet.String())
|
||||
}
|
||||
|
||||
return ips, cidrs
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that blocks blacklisted IPs
|
||||
func (bl *IPBlacklist) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var clientIP string
|
||||
if bl.useProxy {
|
||||
clientIP = getClientIP(r)
|
||||
// Clean up IPv6 brackets if present
|
||||
clientIP = strings.Trim(clientIP, "[]")
|
||||
} else {
|
||||
// Extract IP from RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
clientIP = r.RemoteAddr[:idx]
|
||||
} else {
|
||||
clientIP = r.RemoteAddr
|
||||
}
|
||||
clientIP = strings.Trim(clientIP, "[]")
|
||||
}
|
||||
|
||||
blocked, reason := bl.IsBlocked(clientIP)
|
||||
if blocked {
|
||||
response := map[string]interface{}{
|
||||
"error": "forbidden",
|
||||
"message": "Access denied",
|
||||
}
|
||||
if reason != "" {
|
||||
response["reason"] = reason
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
err := json.NewEncoder(w).Encode(response)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to write blacklist response: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// StatsHandler returns an HTTP handler that shows blacklist statistics
|
||||
func (bl *IPBlacklist) StatsHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ips, cidrs := bl.GetBlacklist()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"blocked_ips": ips,
|
||||
"blocked_cidrs": cidrs,
|
||||
"total_ips": len(ips),
|
||||
"total_cidrs": len(cidrs),
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(stats)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to encode stats: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
254
pkg/middleware/blacklist_test.go
Normal file
254
pkg/middleware/blacklist_test.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIPBlacklist_BlockIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block an IP
|
||||
err := bl.BlockIP("192.168.1.100", "Suspicious activity")
|
||||
if err != nil {
|
||||
t.Fatalf("BlockIP() error = %v", err)
|
||||
}
|
||||
|
||||
// Check if IP is blocked
|
||||
blocked, reason := bl.IsBlocked("192.168.1.100")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
if reason != "Suspicious activity" {
|
||||
t.Errorf("Reason = %q, want %q", reason, "Suspicious activity")
|
||||
}
|
||||
|
||||
// Check non-blocked IP
|
||||
blocked, _ = bl.IsBlocked("192.168.1.1")
|
||||
if blocked {
|
||||
t.Error("IP should not be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_BlockCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block a CIDR range
|
||||
err := bl.BlockCIDR("10.0.0.0/24", "Internal network blocked")
|
||||
if err != nil {
|
||||
t.Fatalf("BlockCIDR() error = %v", err)
|
||||
}
|
||||
|
||||
// Check IPs in range
|
||||
testIPs := []string{
|
||||
"10.0.0.1",
|
||||
"10.0.0.100",
|
||||
"10.0.0.254",
|
||||
}
|
||||
|
||||
for _, ip := range testIPs {
|
||||
blocked, _ := bl.IsBlocked(ip)
|
||||
if !blocked {
|
||||
t.Errorf("IP %s should be blocked by CIDR", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Check IP outside range
|
||||
blocked, _ := bl.IsBlocked("10.0.1.1")
|
||||
if blocked {
|
||||
t.Error("IP outside CIDR range should not be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_UnblockIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block and then unblock
|
||||
bl.BlockIP("192.168.1.100", "Test")
|
||||
|
||||
blocked, _ := bl.IsBlocked("192.168.1.100")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked")
|
||||
}
|
||||
|
||||
bl.UnblockIP("192.168.1.100")
|
||||
|
||||
blocked, _ = bl.IsBlocked("192.168.1.100")
|
||||
if blocked {
|
||||
t.Error("IP should be unblocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_UnblockCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
// Block and then unblock CIDR
|
||||
bl.BlockCIDR("10.0.0.0/24", "Test")
|
||||
|
||||
blocked, _ := bl.IsBlocked("10.0.0.1")
|
||||
if !blocked {
|
||||
t.Error("IP should be blocked by CIDR")
|
||||
}
|
||||
|
||||
bl.UnblockCIDR("10.0.0.0/24")
|
||||
|
||||
blocked, _ = bl.IsBlocked("10.0.0.1")
|
||||
if blocked {
|
||||
t.Error("IP should be unblocked after CIDR removal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_Middleware(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "Banned")
|
||||
|
||||
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
// Blocked IP should get 403
|
||||
t.Run("BlockedIP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.100:12345"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
|
||||
var response map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if response["error"] != "forbidden" {
|
||||
t.Errorf("Error = %v, want %q", response["error"], "forbidden")
|
||||
}
|
||||
})
|
||||
|
||||
// Allowed IP should succeed
|
||||
t.Run("AllowedIP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPBlacklist_MiddlewareWithProxy(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: true})
|
||||
bl.BlockIP("203.0.113.1", "Blocked via proxy")
|
||||
|
||||
handler := bl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Test X-Forwarded-For
|
||||
t.Run("X-Forwarded-For", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
|
||||
// Test X-Real-IP
|
||||
t.Run("X-Real-IP", func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "10.0.0.1:12345"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.1")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPBlacklist_StatsHandler(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "Test1")
|
||||
bl.BlockIP("192.168.1.101", "Test2")
|
||||
bl.BlockCIDR("10.0.0.0/24", "Test CIDR")
|
||||
|
||||
handler := bl.StatsHandler()
|
||||
|
||||
req := httptest.NewRequest("GET", "/blacklist-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total_ips"].(float64)) != 2 {
|
||||
t.Errorf("total_ips = %v, want 2", stats["total_ips"])
|
||||
}
|
||||
|
||||
if int(stats["total_cidrs"].(float64)) != 1 {
|
||||
t.Errorf("total_cidrs = %v, want 1", stats["total_cidrs"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_GetBlacklist(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
bl.BlockIP("192.168.1.100", "")
|
||||
bl.BlockIP("192.168.1.101", "")
|
||||
bl.BlockCIDR("10.0.0.0/24", "")
|
||||
|
||||
ips, cidrs := bl.GetBlacklist()
|
||||
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("len(ips) = %d, want 2", len(ips))
|
||||
}
|
||||
|
||||
if len(cidrs) != 1 {
|
||||
t.Errorf("len(cidrs) = %d, want 1", len(cidrs))
|
||||
}
|
||||
|
||||
// Verify CIDR format
|
||||
if cidrs[0] != "10.0.0.0/24" {
|
||||
t.Errorf("CIDR = %q, want %q", cidrs[0], "10.0.0.0/24")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_InvalidIP(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
err := bl.BlockIP("invalid-ip", "Test")
|
||||
if err == nil {
|
||||
t.Error("BlockIP() should return error for invalid IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlacklist_InvalidCIDR(t *testing.T) {
|
||||
bl := NewIPBlacklist(BlacklistConfig{UseProxy: false})
|
||||
|
||||
err := bl.BlockCIDR("invalid-cidr", "Test")
|
||||
if err == nil {
|
||||
t.Error("BlockCIDR() should return error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
233
pkg/middleware/ratelimit.go
Normal file
233
pkg/middleware/ratelimit.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Package middleware provides HTTP middleware functionalities such as rate limiting and IP blacklisting.
|
||||
package middleware
|
||||
|
||||
//nolint:all
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiter provides rate limiting functionality
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
limiters map[string]*rate.Limiter
|
||||
rate rate.Limit
|
||||
burst int
|
||||
cleanup time.Duration
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
// rps is requests per second, burst is the maximum burst size
|
||||
func NewRateLimiter(rps float64, burst int) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
rate: rate.Limit(rps),
|
||||
burst: burst,
|
||||
cleanup: 5 * time.Minute, // Clean up stale limiters every 5 minutes
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanupRoutine()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// getLimiter returns the rate limiter for a given key (e.g., IP address)
|
||||
func (rl *RateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
// Double-check after acquiring write lock
|
||||
if limiter, exists := rl.limiters[key]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
limiter = rate.NewLimiter(rl.rate, rl.burst)
|
||||
rl.limiters[key] = limiter
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine periodically removes inactive limiters
|
||||
func (rl *RateLimiter) cleanupRoutine() {
|
||||
ticker := time.NewTicker(rl.cleanup)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
// Simple cleanup: remove all limiters
|
||||
// In production, you might want to track last access time
|
||||
rl.limiters = make(map[string]*rate.Limiter)
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that applies rate limiting
|
||||
// Automatically handles X-Forwarded-For headers when behind a proxy
|
||||
func (rl *RateLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Extract client IP, handling proxy headers
|
||||
key := getClientIP(r)
|
||||
|
||||
limiter := rl.getLimiter(key)
|
||||
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// MiddlewareWithKeyFunc returns an HTTP middleware with a custom key extraction function
|
||||
func (rl *RateLimiter) MiddlewareWithKeyFunc(keyFunc func(*http.Request) string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := keyFunc(r)
|
||||
if key == "" {
|
||||
key = r.RemoteAddr
|
||||
}
|
||||
|
||||
limiter := rl.getLimiter(key)
|
||||
|
||||
if !limiter.Allow() {
|
||||
http.Error(w, `{"error":"rate_limit_exceeded","message":"Too many requests"}`, http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitInfo contains information about a specific IP's rate limit status
|
||||
type RateLimitInfo struct {
|
||||
IP string `json:"ip"`
|
||||
TokensRemaining float64 `json:"tokens_remaining"`
|
||||
Limit float64 `json:"limit"`
|
||||
Burst int `json:"burst"`
|
||||
}
|
||||
|
||||
// GetTrackedIPs returns all IPs currently being tracked by the rate limiter
|
||||
func (rl *RateLimiter) GetTrackedIPs() []string {
|
||||
rl.mu.RLock()
|
||||
defer rl.mu.RUnlock()
|
||||
|
||||
ips := make([]string, 0, len(rl.limiters))
|
||||
for ip := range rl.limiters {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
return ips
|
||||
}
|
||||
|
||||
// GetRateLimitInfo returns rate limit information for a specific IP
|
||||
func (rl *RateLimiter) GetRateLimitInfo(ip string) *RateLimitInfo {
|
||||
rl.mu.RLock()
|
||||
limiter, exists := rl.limiters[ip]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// Return default info for untracked IP
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: float64(rl.burst),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
return &RateLimitInfo{
|
||||
IP: ip,
|
||||
TokensRemaining: limiter.Tokens(),
|
||||
Limit: float64(rl.rate),
|
||||
Burst: rl.burst,
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllRateLimitInfo returns rate limit information for all tracked IPs
|
||||
func (rl *RateLimiter) GetAllRateLimitInfo() []*RateLimitInfo {
|
||||
ips := rl.GetTrackedIPs()
|
||||
info := make([]*RateLimitInfo, 0, len(ips))
|
||||
|
||||
for _, ip := range ips {
|
||||
info = append(info, rl.GetRateLimitInfo(ip))
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// StatsHandler returns an HTTP handler that exposes rate limit statistics
|
||||
// Example: GET /rate-limit-stats
|
||||
func (rl *RateLimiter) StatsHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Support querying specific IP via ?ip=x.x.x.x
|
||||
if ip := r.URL.Query().Get("ip"); ip != "" {
|
||||
info := rl.GetRateLimitInfo(ip)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(info)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to encode json: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Return all tracked IPs
|
||||
allInfo := rl.GetAllRateLimitInfo()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_tracked_ips": len(allInfo),
|
||||
"rate_limit_config": map[string]interface{}{
|
||||
"requests_per_second": float64(rl.rate),
|
||||
"burst": rl.burst,
|
||||
},
|
||||
"tracked_ips": allInfo,
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(stats)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to encode json: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// getClientIP extracts the real client IP from the request
|
||||
// Handles X-Forwarded-For, X-Real-IP, and falls back to RemoteAddr
|
||||
func getClientIP(r *http.Request) string {
|
||||
// Check X-Forwarded-For header (most common in production)
|
||||
// Format: X-Forwarded-For: client, proxy1, proxy2
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first IP (the original client)
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// Check X-Real-IP header (used by some proxies like nginx)
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// Fall back to RemoteAddr
|
||||
// Remove port if present (format: "ip:port")
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
|
||||
return r.RemoteAddr
|
||||
}
|
||||
388
pkg/middleware/ratelimit_test.go
Normal file
388
pkg/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
// Create rate limiter: 2 requests per second, burst of 2
|
||||
rl := NewRateLimiter(2, 2)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
// First request should succeed
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("First request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Second request should succeed (within burst)
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Second request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("Third request should be rate limited: got %d, want %d", w.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
// Wait for rate limiter to refill
|
||||
time.Sleep(600 * time.Millisecond)
|
||||
|
||||
// Request should succeed again
|
||||
w = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request after wait failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||
rl := NewRateLimiter(1, 1)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First IP
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
|
||||
// Second IP
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.2:12345"
|
||||
|
||||
// Both should succeed (different IPs)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("First IP request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("Second IP request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
expectedIP string
|
||||
}{
|
||||
{
|
||||
name: "RemoteAddr only",
|
||||
remoteAddr: "192.168.1.1:12345",
|
||||
expectedIP: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For single IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For multiple IPs",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1, 10.0.0.2, 10.0.0.3",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xRealIP: "203.0.113.1",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Forwarded-For takes precedence over X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:12345",
|
||||
xForwardedFor: "203.0.113.1",
|
||||
xRealIP: "203.0.113.2",
|
||||
expectedIP: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "IPv6 address",
|
||||
remoteAddr: "[2001:db8::1]:12345",
|
||||
expectedIP: "[2001:db8::1]",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = tt.remoteAddr
|
||||
|
||||
if tt.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
||||
}
|
||||
if tt.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tt.xRealIP)
|
||||
}
|
||||
|
||||
ip := getClientIP(req)
|
||||
if ip != tt.expectedIP {
|
||||
t.Errorf("getClientIP() = %q, want %q", ip, tt.expectedIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterWithCustomKeyFunc(t *testing.T) {
|
||||
rl := NewRateLimiter(1, 1)
|
||||
|
||||
// Use user ID as key
|
||||
keyFunc := func(r *http.Request) string {
|
||||
userID := r.Header.Get("X-User-ID")
|
||||
if userID == "" {
|
||||
return r.RemoteAddr
|
||||
}
|
||||
return "user:" + userID
|
||||
}
|
||||
|
||||
handler := rl.MiddlewareWithKeyFunc(keyFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// User 1
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.Header.Set("X-User-ID", "user1")
|
||||
|
||||
// User 2
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("X-User-ID", "user2")
|
||||
|
||||
// Both users should succeed (different keys)
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
if w1.Code != http.StatusOK {
|
||||
t.Errorf("User 1 request failed: got %d, want %d", w1.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
if w2.Code != http.StatusOK {
|
||||
t.Errorf("User 2 request failed: got %d, want %d", w2.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// User 1 second request should be rate limited
|
||||
w1 = httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
if w1.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("User 1 second request should be rate limited: got %d, want %d", w1.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetTrackedIPs(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
ips := []string{"192.168.1.1", "192.168.1.2", "192.168.1.3"}
|
||||
for _, ip := range ips {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = ip + ":12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Check tracked IPs
|
||||
trackedIPs := rl.GetTrackedIPs()
|
||||
if len(trackedIPs) != len(ips) {
|
||||
t.Errorf("len(trackedIPs) = %d, want %d", len(trackedIPs), len(ips))
|
||||
}
|
||||
|
||||
// Verify all IPs are tracked
|
||||
ipMap := make(map[string]bool)
|
||||
for _, ip := range trackedIPs {
|
||||
ipMap[ip] = true
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
if !ipMap[ip] {
|
||||
t.Errorf("IP %s should be tracked", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetRateLimitInfo(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make a request
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = "192.168.1.1:12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
// Get rate limit info
|
||||
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
|
||||
if info.Limit != 10.0 {
|
||||
t.Errorf("Limit = %f, want 10.0", info.Limit)
|
||||
}
|
||||
|
||||
if info.Burst != 5 {
|
||||
t.Errorf("Burst = %d, want 5", info.Burst)
|
||||
}
|
||||
|
||||
// Tokens should be less than burst after one request
|
||||
if info.TokensRemaining >= float64(info.Burst) {
|
||||
t.Errorf("TokensRemaining = %f, should be less than %d", info.TokensRemaining, info.Burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetRateLimitInfo_UntrackedIP(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
// Get info for untracked IP (should return default)
|
||||
info := rl.GetRateLimitInfo("192.168.1.1")
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
|
||||
if info.TokensRemaining != float64(rl.burst) {
|
||||
t.Errorf("TokensRemaining = %f, want %d (full burst)", info.TokensRemaining, rl.burst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_GetAllRateLimitInfo(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 10)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
ips := []string{"192.168.1.1", "192.168.1.2"}
|
||||
for _, ip := range ips {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.RemoteAddr = ip + ":12345"
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Get all rate limit info
|
||||
allInfo := rl.GetAllRateLimitInfo()
|
||||
|
||||
if len(allInfo) != len(ips) {
|
||||
t.Errorf("len(allInfo) = %d, want %d", len(allInfo), len(ips))
|
||||
}
|
||||
|
||||
// Verify each IP has info
|
||||
for _, info := range allInfo {
|
||||
found := false
|
||||
for _, ip := range ips {
|
||||
if info.IP == ip {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Unexpected IP in info: %s", info.IP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_StatsHandler(t *testing.T) {
|
||||
rl := NewRateLimiter(10, 5)
|
||||
|
||||
handler := rl.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make requests from different IPs
|
||||
req1 := httptest.NewRequest("GET", "/test", nil)
|
||||
req1.RemoteAddr = "192.168.1.1:12345"
|
||||
w1 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w1, req1)
|
||||
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.RemoteAddr = "192.168.1.2:12345"
|
||||
w2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w2, req2)
|
||||
|
||||
// Test stats handler (all IPs)
|
||||
t.Run("AllIPs", func(t *testing.T) {
|
||||
statsHandler := rl.StatsHandler()
|
||||
req := httptest.NewRequest("GET", "/rate-limit-stats", nil)
|
||||
w := httptest.NewRecorder()
|
||||
statsHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var stats map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &stats); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if int(stats["total_tracked_ips"].(float64)) != 2 {
|
||||
t.Errorf("total_tracked_ips = %v, want 2", stats["total_tracked_ips"])
|
||||
}
|
||||
|
||||
config := stats["rate_limit_config"].(map[string]interface{})
|
||||
if config["requests_per_second"].(float64) != 10.0 {
|
||||
t.Errorf("requests_per_second = %v, want 10.0", config["requests_per_second"])
|
||||
}
|
||||
})
|
||||
|
||||
// Test stats handler (specific IP)
|
||||
t.Run("SpecificIP", func(t *testing.T) {
|
||||
statsHandler := rl.StatsHandler()
|
||||
req := httptest.NewRequest("GET", "/rate-limit-stats?ip=192.168.1.1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
statsHandler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Status = %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
var info RateLimitInfo
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &info); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if info.IP != "192.168.1.1" {
|
||||
t.Errorf("IP = %q, want %q", info.IP, "192.168.1.1")
|
||||
}
|
||||
})
|
||||
}
|
||||
251
pkg/middleware/sanitize.go
Normal file
251
pkg/middleware/sanitize.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"html"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Sanitizer provides input sanitization beyond SQL injection protection
|
||||
type Sanitizer struct {
|
||||
// StripHTML removes HTML tags from input
|
||||
StripHTML bool
|
||||
|
||||
// EscapeHTML escapes HTML entities
|
||||
EscapeHTML bool
|
||||
|
||||
// RemoveNullBytes removes null bytes from input
|
||||
RemoveNullBytes bool
|
||||
|
||||
// RemoveControlChars removes control characters (except newline, carriage return, tab)
|
||||
RemoveControlChars bool
|
||||
|
||||
// MaxStringLength limits individual string field length (0 = no limit)
|
||||
MaxStringLength int
|
||||
|
||||
// BlockPatterns are regex patterns to block (e.g., script tags, SQL keywords)
|
||||
BlockPatterns []*regexp.Regexp
|
||||
|
||||
// Custom sanitization function
|
||||
CustomSanitizer func(string) string
|
||||
}
|
||||
|
||||
// DefaultSanitizer returns a sanitizer with secure defaults
|
||||
func DefaultSanitizer() *Sanitizer {
|
||||
return &Sanitizer{
|
||||
StripHTML: false, // Don't strip by default (breaks legitimate HTML content)
|
||||
EscapeHTML: true, // Escape HTML entities to prevent XSS
|
||||
RemoveNullBytes: true, // Remove null bytes (security best practice)
|
||||
RemoveControlChars: true, // Remove dangerous control characters
|
||||
MaxStringLength: 0, // No limit by default
|
||||
|
||||
// Block common XSS and injection patterns
|
||||
BlockPatterns: []*regexp.Regexp{
|
||||
regexp.MustCompile(`(?i)<script[^>]*>.*?</script>`), // Script tags
|
||||
regexp.MustCompile(`(?i)javascript:`), // JavaScript protocol
|
||||
regexp.MustCompile(`(?i)on\w+\s*=`), // Event handlers (onclick, onerror, etc.)
|
||||
regexp.MustCompile(`(?i)<iframe[^>]*>`), // Iframes
|
||||
regexp.MustCompile(`(?i)<object[^>]*>`), // Objects
|
||||
regexp.MustCompile(`(?i)<embed[^>]*>`), // Embeds
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// StrictSanitizer returns a sanitizer with very strict rules
|
||||
func StrictSanitizer() *Sanitizer {
|
||||
s := DefaultSanitizer()
|
||||
s.StripHTML = true
|
||||
s.MaxStringLength = 10000
|
||||
return s
|
||||
}
|
||||
|
||||
// Sanitize sanitizes a string value
|
||||
func (s *Sanitizer) Sanitize(value string) string {
|
||||
if value == "" {
|
||||
return value
|
||||
}
|
||||
|
||||
// Remove null bytes
|
||||
if s.RemoveNullBytes {
|
||||
value = strings.ReplaceAll(value, "\x00", "")
|
||||
}
|
||||
|
||||
// Remove control characters
|
||||
if s.RemoveControlChars {
|
||||
value = removeControlCharacters(value)
|
||||
}
|
||||
|
||||
// Check block patterns
|
||||
for _, pattern := range s.BlockPatterns {
|
||||
if pattern.MatchString(value) {
|
||||
// Replace matched pattern with empty string
|
||||
value = pattern.ReplaceAllString(value, "")
|
||||
}
|
||||
}
|
||||
|
||||
// Strip HTML tags
|
||||
if s.StripHTML {
|
||||
value = stripHTMLTags(value)
|
||||
}
|
||||
|
||||
// Escape HTML entities
|
||||
if s.EscapeHTML && !s.StripHTML {
|
||||
value = html.EscapeString(value)
|
||||
}
|
||||
|
||||
// Apply max length
|
||||
if s.MaxStringLength > 0 && len(value) > s.MaxStringLength {
|
||||
value = value[:s.MaxStringLength]
|
||||
}
|
||||
|
||||
// Apply custom sanitizer
|
||||
if s.CustomSanitizer != nil {
|
||||
value = s.CustomSanitizer(value)
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
// SanitizeMap sanitizes all string values in a map
|
||||
func (s *Sanitizer) SanitizeMap(data map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
for key, value := range data {
|
||||
result[key] = s.sanitizeValue(value)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// sanitizeValue recursively sanitizes values
|
||||
func (s *Sanitizer) sanitizeValue(value interface{}) interface{} {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return s.Sanitize(v)
|
||||
case map[string]interface{}:
|
||||
return s.SanitizeMap(v)
|
||||
case []interface{}:
|
||||
result := make([]interface{}, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = s.sanitizeValue(item)
|
||||
}
|
||||
return result
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that sanitizes request headers and query params
|
||||
// Note: Body sanitization should be done at the application level after parsing
|
||||
func (s *Sanitizer) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Sanitize query parameters
|
||||
if r.URL.RawQuery != "" {
|
||||
q := r.URL.Query()
|
||||
sanitized := false
|
||||
for key, values := range q {
|
||||
for i, value := range values {
|
||||
sanitizedValue := s.Sanitize(value)
|
||||
if sanitizedValue != value {
|
||||
values[i] = sanitizedValue
|
||||
sanitized = true
|
||||
}
|
||||
}
|
||||
if sanitized {
|
||||
q[key] = values
|
||||
}
|
||||
}
|
||||
if sanitized {
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// Sanitize specific headers (User-Agent, Referer, etc.)
|
||||
dangerousHeaders := []string{
|
||||
"User-Agent",
|
||||
"Referer",
|
||||
"X-Forwarded-For",
|
||||
"X-Real-IP",
|
||||
}
|
||||
|
||||
for _, header := range dangerousHeaders {
|
||||
if value := r.Header.Get(header); value != "" {
|
||||
sanitized := s.Sanitize(value)
|
||||
if sanitized != value {
|
||||
r.Header.Set(header, sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// removeControlCharacters removes control characters except \n, \r, \t
|
||||
func removeControlCharacters(s string) string {
|
||||
var result strings.Builder
|
||||
for _, r := range s {
|
||||
// Keep newline, carriage return, tab, and non-control characters
|
||||
if r == '\n' || r == '\r' || r == '\t' || r >= 32 {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// stripHTMLTags removes HTML tags from a string
|
||||
func stripHTMLTags(s string) string {
|
||||
// Simple regex to remove HTML tags
|
||||
re := regexp.MustCompile(`<[^>]*>`)
|
||||
return re.ReplaceAllString(s, "")
|
||||
}
|
||||
|
||||
// Common sanitization patterns
|
||||
|
||||
// SanitizeFilename sanitizes a filename
|
||||
func SanitizeFilename(filename string) string {
|
||||
// Remove path traversal attempts
|
||||
filename = strings.ReplaceAll(filename, "..", "")
|
||||
filename = strings.ReplaceAll(filename, "/", "")
|
||||
filename = strings.ReplaceAll(filename, "\\", "")
|
||||
|
||||
// Remove null bytes
|
||||
filename = strings.ReplaceAll(filename, "\x00", "")
|
||||
|
||||
// Limit length
|
||||
if len(filename) > 255 {
|
||||
filename = filename[:255]
|
||||
}
|
||||
|
||||
return filename
|
||||
}
|
||||
|
||||
// SanitizeEmail performs basic email sanitization
|
||||
func SanitizeEmail(email string) string {
|
||||
email = strings.TrimSpace(strings.ToLower(email))
|
||||
|
||||
// Remove dangerous characters
|
||||
email = strings.ReplaceAll(email, "\x00", "")
|
||||
email = removeControlCharacters(email)
|
||||
|
||||
return email
|
||||
}
|
||||
|
||||
// SanitizeURL performs basic URL sanitization
|
||||
func SanitizeURL(url string) string {
|
||||
url = strings.TrimSpace(url)
|
||||
|
||||
// Remove null bytes
|
||||
url = strings.ReplaceAll(url, "\x00", "")
|
||||
|
||||
// Block javascript: and data: protocols
|
||||
if strings.HasPrefix(strings.ToLower(url), "javascript:") {
|
||||
return ""
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(url), "data:") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return url
|
||||
}
|
||||
273
pkg/middleware/sanitize_test.go
Normal file
273
pkg/middleware/sanitize_test.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSanitizeXSS(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
contains string // String that should NOT be in output
|
||||
}{
|
||||
{
|
||||
name: "Script tag",
|
||||
input: "<script>alert(1)</script>",
|
||||
contains: "<script>",
|
||||
},
|
||||
{
|
||||
name: "JavaScript protocol",
|
||||
input: "javascript:alert(1)",
|
||||
contains: "javascript:",
|
||||
},
|
||||
{
|
||||
name: "Event handler",
|
||||
input: "<img onerror='alert(1)'>",
|
||||
contains: "onerror=",
|
||||
},
|
||||
{
|
||||
name: "Iframe",
|
||||
input: "<iframe src='evil.com'></iframe>",
|
||||
contains: "<iframe",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizer.Sanitize(tt.input)
|
||||
if result == tt.input {
|
||||
t.Errorf("Sanitize() did not modify input: %q", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeNullBytes(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
input := "hello\x00world"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if result == input {
|
||||
t.Error("Null bytes should be removed")
|
||||
}
|
||||
|
||||
if len(result) >= len(input) {
|
||||
t.Errorf("Result length should be less than input: got %d, input %d", len(result), len(input))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeControlCharacters(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
// Include various control characters
|
||||
input := "hello\x01\x02world\x1F"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if result == input {
|
||||
t.Error("Control characters should be removed")
|
||||
}
|
||||
|
||||
// Newlines, tabs, carriage returns should be preserved
|
||||
input2 := "hello\nworld\t\r"
|
||||
result2 := sanitizer.Sanitize(input2)
|
||||
|
||||
if result2 != input2 {
|
||||
t.Errorf("Safe control characters should be preserved: got %q, want %q", result2, input2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMap(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
input := map[string]interface{}{
|
||||
"name": "<script>alert(1)</script>John",
|
||||
"email": "test@example.com",
|
||||
"nested": map[string]interface{}{
|
||||
"bio": "<iframe src='evil.com'>Bio</iframe>",
|
||||
},
|
||||
}
|
||||
|
||||
result := sanitizer.SanitizeMap(input)
|
||||
|
||||
// Check that script tag was removed/escaped
|
||||
name, ok := result["name"].(string)
|
||||
if !ok || name == input["name"] {
|
||||
t.Error("Name should be sanitized")
|
||||
}
|
||||
|
||||
// Check nested map
|
||||
nested, ok := result["nested"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Nested should still be a map")
|
||||
}
|
||||
|
||||
bio, ok := nested["bio"].(string)
|
||||
if !ok || bio == input["nested"].(map[string]interface{})["bio"] {
|
||||
t.Error("Nested bio should be sanitized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMiddleware(t *testing.T) {
|
||||
sanitizer := DefaultSanitizer()
|
||||
|
||||
handler := sanitizer.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check that query param was sanitized
|
||||
param := r.URL.Query().Get("q")
|
||||
if param == "<script>alert(1)</script>" {
|
||||
t.Error("Query param should be sanitized")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test?q=<script>alert(1)</script>", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Handler failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeFilename(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
contains string // String that should NOT be in output
|
||||
}{
|
||||
{
|
||||
name: "Path traversal",
|
||||
input: "../../../etc/passwd",
|
||||
contains: "..",
|
||||
},
|
||||
{
|
||||
name: "Absolute path",
|
||||
input: "/etc/passwd",
|
||||
contains: "/",
|
||||
},
|
||||
{
|
||||
name: "Windows path",
|
||||
input: "..\\..\\windows\\system32",
|
||||
contains: "\\",
|
||||
},
|
||||
{
|
||||
name: "Null byte",
|
||||
input: "file\x00.txt",
|
||||
contains: "\x00",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
if result == tt.input {
|
||||
t.Errorf("SanitizeFilename() did not modify input: %q", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Uppercase",
|
||||
input: "TEST@EXAMPLE.COM",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Whitespace",
|
||||
input: " test@example.com ",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
{
|
||||
name: "Null bytes",
|
||||
input: "test\x00@example.com",
|
||||
expected: "test@example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeEmail(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeEmail() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "JavaScript protocol",
|
||||
input: "javascript:alert(1)",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Data protocol",
|
||||
input: "data:text/html,<script>alert(1)</script>",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Valid HTTP URL",
|
||||
input: "https://example.com",
|
||||
expected: "https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeURL(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStrictSanitizer(t *testing.T) {
|
||||
sanitizer := StrictSanitizer()
|
||||
|
||||
input := "<b>Bold text</b> with <script>alert(1)</script>"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
// Should strip ALL HTML tags
|
||||
if result == input {
|
||||
t.Error("Strict sanitizer should modify input")
|
||||
}
|
||||
|
||||
// Should not contain any HTML tags
|
||||
if len(result) > 0 && (result[0] == '<' || result[len(result)-1] == '>') {
|
||||
t.Error("Result should not contain HTML tags")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxStringLength(t *testing.T) {
|
||||
sanitizer := &Sanitizer{
|
||||
MaxStringLength: 10,
|
||||
}
|
||||
|
||||
input := "This is a very long string that exceeds the maximum length"
|
||||
result := sanitizer.Sanitize(input)
|
||||
|
||||
if len(result) != 10 {
|
||||
t.Errorf("Result length = %d, want 10", len(result))
|
||||
}
|
||||
|
||||
if result != input[:10] {
|
||||
t.Errorf("Result = %q, want %q", result, input[:10])
|
||||
}
|
||||
}
|
||||
70
pkg/middleware/sizelimit.go
Normal file
70
pkg/middleware/sizelimit.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultMaxRequestSize is the default maximum request body size (10MB)
|
||||
DefaultMaxRequestSize = 10 * 1024 * 1024 // 10MB
|
||||
|
||||
// MaxRequestSizeHeader is the header name for max request size
|
||||
MaxRequestSizeHeader = "X-Max-Request-Size"
|
||||
)
|
||||
|
||||
// RequestSizeLimiter limits the size of request bodies
|
||||
type RequestSizeLimiter struct {
|
||||
maxSize int64
|
||||
}
|
||||
|
||||
// NewRequestSizeLimiter creates a new request size limiter
|
||||
// maxSize is in bytes. If 0, uses DefaultMaxRequestSize (10MB)
|
||||
func NewRequestSizeLimiter(maxSize int64) *RequestSizeLimiter {
|
||||
if maxSize <= 0 {
|
||||
maxSize = DefaultMaxRequestSize
|
||||
}
|
||||
return &RequestSizeLimiter{
|
||||
maxSize: maxSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that enforces request size limits
|
||||
func (rsl *RequestSizeLimiter) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set max bytes reader on the request body
|
||||
r.Body = http.MaxBytesReader(w, r.Body, rsl.maxSize)
|
||||
|
||||
// Add informational header
|
||||
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", rsl.maxSize))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// MiddlewareWithCustomSize returns middleware with a custom size limit function
|
||||
// This allows different size limits based on the request
|
||||
func (rsl *RequestSizeLimiter) MiddlewareWithCustomSize(sizeFunc func(*http.Request) int64) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
maxSize := sizeFunc(r)
|
||||
if maxSize <= 0 {
|
||||
maxSize = rsl.maxSize
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxSize)
|
||||
w.Header().Set(MaxRequestSizeHeader, fmt.Sprintf("%d", maxSize))
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Common size limits
|
||||
const (
|
||||
Size1MB = 1 * 1024 * 1024
|
||||
Size5MB = 5 * 1024 * 1024
|
||||
Size10MB = 10 * 1024 * 1024
|
||||
Size50MB = 50 * 1024 * 1024
|
||||
Size100MB = 100 * 1024 * 1024
|
||||
)
|
||||
126
pkg/middleware/sizelimit_test.go
Normal file
126
pkg/middleware/sizelimit_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRequestSizeLimiter(t *testing.T) {
|
||||
// 1KB limit
|
||||
limiter := NewRequestSizeLimiter(1024)
|
||||
|
||||
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Try to read body
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Small request (should succeed)
|
||||
t.Run("SmallRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 512)) // 512 bytes
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Small request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Check header
|
||||
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "1024" {
|
||||
t.Errorf("MaxRequestSizeHeader = %q, want %q", maxSize, "1024")
|
||||
}
|
||||
})
|
||||
|
||||
// Large request (should fail)
|
||||
t.Run("LargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048)) // 2KB
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRequestSizeLimiterDefault(t *testing.T) {
|
||||
// Default limiter (10MB)
|
||||
limiter := NewRequestSizeLimiter(0)
|
||||
|
||||
handler := limiter.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("POST", "/test", bytes.NewReader(make([]byte, 1024)))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
// Check default size
|
||||
if maxSize := w.Header().Get(MaxRequestSizeHeader); maxSize != "10485760" {
|
||||
t.Errorf("Default MaxRequestSizeHeader = %q, want %q", maxSize, "10485760")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSizeLimiterWithCustomSize(t *testing.T) {
|
||||
limiter := NewRequestSizeLimiter(1024)
|
||||
|
||||
// Premium users get 10MB, regular users get 1KB
|
||||
sizeFunc := func(r *http.Request) int64 {
|
||||
if r.Header.Get("X-User-Tier") == "premium" {
|
||||
return Size10MB
|
||||
}
|
||||
return 1024
|
||||
}
|
||||
|
||||
handler := limiter.MiddlewareWithCustomSize(sizeFunc)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Regular user with large request (should fail)
|
||||
t.Run("RegularUserLargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048))
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Errorf("Regular user large request should fail: got %d, want %d", w.Code, http.StatusRequestEntityTooLarge)
|
||||
}
|
||||
})
|
||||
|
||||
// Premium user with large request (should succeed)
|
||||
t.Run("PremiumUserLargeRequest", func(t *testing.T) {
|
||||
body := bytes.NewReader(make([]byte, 2048))
|
||||
req := httptest.NewRequest("POST", "/test", body)
|
||||
req.Header.Set("X-User-Tier", "premium")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Premium user large request failed: got %d, want %d", w.Code, http.StatusOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -17,6 +17,10 @@ var defaultRegistry = &DefaultModelRegistry{
|
||||
models: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Global list of registries (searched in order)
|
||||
var registries = []*DefaultModelRegistry{defaultRegistry}
|
||||
var registriesMutex sync.RWMutex
|
||||
|
||||
// NewModelRegistry creates a new model registry
|
||||
func NewModelRegistry() *DefaultModelRegistry {
|
||||
return &DefaultModelRegistry{
|
||||
@@ -24,6 +28,33 @@ func NewModelRegistry() *DefaultModelRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
func SetDefaultRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
defer registriesMutex.Unlock()
|
||||
|
||||
foundAt := -1
|
||||
for idx, r := range registries {
|
||||
if r == defaultRegistry {
|
||||
foundAt = idx
|
||||
break
|
||||
}
|
||||
}
|
||||
defaultRegistry = registry
|
||||
if foundAt >= 0 {
|
||||
registries[foundAt] = registry
|
||||
} else {
|
||||
registries = append([]*DefaultModelRegistry{registry}, registries...)
|
||||
}
|
||||
}
|
||||
|
||||
// AddRegistry adds a registry to the global list of registries
|
||||
// Registries are searched in the order they were added
|
||||
func AddRegistry(registry *DefaultModelRegistry) {
|
||||
registriesMutex.Lock()
|
||||
defer registriesMutex.Unlock()
|
||||
registries = append(registries, registry)
|
||||
}
|
||||
|
||||
func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) error {
|
||||
r.mutex.Lock()
|
||||
defer r.mutex.Unlock()
|
||||
@@ -69,19 +100,19 @@ func (r *DefaultModelRegistry) RegisterModel(name string, model interface{}) err
|
||||
func (r *DefaultModelRegistry) GetModel(name string) (interface{}, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
|
||||
model, exists := r.models[name]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("model %s not found", name)
|
||||
}
|
||||
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func (r *DefaultModelRegistry) GetAllModels() map[string]interface{} {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for k, v := range r.models {
|
||||
result[k] = v
|
||||
@@ -107,9 +138,19 @@ func RegisterModel(model interface{}, name string) error {
|
||||
return defaultRegistry.RegisterModel(name, model)
|
||||
}
|
||||
|
||||
// GetModelByName retrieves a model from the default global registry by name
|
||||
// GetModelByName retrieves a model by searching through all registries in order
|
||||
// Returns the first match found
|
||||
func GetModelByName(name string) (interface{}, error) {
|
||||
return defaultRegistry.GetModel(name)
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
for _, registry := range registries {
|
||||
if model, err := registry.GetModel(name); err == nil {
|
||||
return model, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("model %s not found in any registry", name)
|
||||
}
|
||||
|
||||
// IterateModels iterates over all models in the default global registry
|
||||
@@ -122,14 +163,26 @@ func IterateModels(fn func(name string, model interface{})) {
|
||||
}
|
||||
}
|
||||
|
||||
// GetModels returns a list of all models in the default global registry
|
||||
// GetModels returns a list of all models from all registries
|
||||
// Models are collected in registry order, with duplicates included
|
||||
func GetModels() []interface{} {
|
||||
defaultRegistry.mutex.RLock()
|
||||
defer defaultRegistry.mutex.RUnlock()
|
||||
registriesMutex.RLock()
|
||||
defer registriesMutex.RUnlock()
|
||||
|
||||
models := make([]interface{}, 0, len(defaultRegistry.models))
|
||||
for _, model := range defaultRegistry.models {
|
||||
models = append(models, model)
|
||||
var models []interface{}
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, registry := range registries {
|
||||
registry.mutex.RLock()
|
||||
for name, model := range registry.models {
|
||||
// Only add the first occurrence of each model name
|
||||
if !seen[name] {
|
||||
models = append(models, model)
|
||||
seen[name] = true
|
||||
}
|
||||
}
|
||||
registry.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return models
|
||||
}
|
||||
}
|
||||
|
||||
128
pkg/reflection/generic_model.go
Normal file
128
pkg/reflection/generic_model.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type ModelFieldDetail struct {
|
||||
Name string `json:"name"`
|
||||
DataType string `json:"datatype"`
|
||||
SQLName string `json:"sqlname"`
|
||||
SQLDataType string `json:"sqldatatype"`
|
||||
SQLKey string `json:"sqlkey"`
|
||||
Nullable bool `json:"nullable"`
|
||||
FieldValue reflect.Value `json:"-"`
|
||||
}
|
||||
|
||||
// GetModelColumnDetail - Get a list of columns in the SQL declaration of the model
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumnDetail(record reflect.Value) []ModelFieldDetail {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Panic in GetModelColumnDetail : %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
lst := make([]ModelFieldDetail, 0)
|
||||
|
||||
if !record.IsValid() {
|
||||
return lst
|
||||
}
|
||||
if record.Kind() == reflect.Pointer || record.Kind() == reflect.Interface {
|
||||
record = record.Elem()
|
||||
}
|
||||
if record.Kind() != reflect.Struct {
|
||||
return lst
|
||||
}
|
||||
|
||||
collectFieldDetails(record, &lst)
|
||||
|
||||
return lst
|
||||
}
|
||||
|
||||
// collectFieldDetails recursively collects field details from a struct value and its embedded fields
|
||||
func collectFieldDetails(record reflect.Value, lst *[]ModelFieldDetail) {
|
||||
modeltype := record.Type()
|
||||
|
||||
for i := 0; i < modeltype.NumField(); i++ {
|
||||
fieldtype := modeltype.Field(i)
|
||||
fieldValue := record.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if fieldtype.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
embeddedValue := fieldValue
|
||||
if fieldValue.Kind() == reflect.Pointer {
|
||||
if fieldValue.IsNil() {
|
||||
// Skip nil embedded pointers
|
||||
continue
|
||||
}
|
||||
embeddedValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if embeddedValue.Kind() == reflect.Struct {
|
||||
collectFieldDetails(embeddedValue, lst)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
gormdetail := fieldtype.Tag.Get("gorm")
|
||||
gormdetail = strings.Trim(gormdetail, " ")
|
||||
fielddetail := ModelFieldDetail{}
|
||||
fielddetail.FieldValue = fieldValue
|
||||
fielddetail.Name = fieldtype.Name
|
||||
fielddetail.DataType = fieldtype.Type.Name()
|
||||
fielddetail.SQLName = fnFindKeyVal(gormdetail, "column:")
|
||||
fielddetail.SQLDataType = fnFindKeyVal(gormdetail, "type:")
|
||||
gormdetailLower := strings.ToLower(gormdetail)
|
||||
switch {
|
||||
case strings.Index(gormdetailLower, "identity") > 0 || strings.Index(gormdetailLower, "primary_key") > 0:
|
||||
fielddetail.SQLKey = "primary_key"
|
||||
case strings.Contains(gormdetailLower, "unique"):
|
||||
fielddetail.SQLKey = "unique"
|
||||
case strings.Contains(gormdetailLower, "uniqueindex"):
|
||||
fielddetail.SQLKey = "uniqueindex"
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(gormdetail), "nullable") {
|
||||
fielddetail.Nullable = true
|
||||
} else if strings.Contains(strings.ToLower(gormdetail), "null") {
|
||||
fielddetail.Nullable = true
|
||||
}
|
||||
if strings.Contains(strings.ToLower(gormdetail), "not null") {
|
||||
fielddetail.Nullable = false
|
||||
}
|
||||
|
||||
if strings.Contains(strings.ToLower(gormdetail), "foreignkey:") {
|
||||
fielddetail.SQLKey = "foreign_key"
|
||||
ik := strings.Index(strings.ToLower(gormdetail), "foreignkey:")
|
||||
ie := strings.Index(gormdetail[ik:], ";")
|
||||
if ie > ik && ik > 0 {
|
||||
fielddetail.SQLName = strings.ToLower(gormdetail)[ik+11 : ik+ie]
|
||||
// fmt.Printf("\r\nforeignkey: %v", fielddetail)
|
||||
}
|
||||
|
||||
}
|
||||
// ";foreignkey:rid_parent;association_foreignkey:id_atevent;save_associations:false;association_autocreate:false;"
|
||||
|
||||
*lst = append(*lst, fielddetail)
|
||||
}
|
||||
}
|
||||
|
||||
func fnFindKeyVal(src, key string) string {
|
||||
icolStart := strings.Index(strings.ToLower(src), strings.ToLower(key))
|
||||
val := ""
|
||||
if icolStart >= 0 {
|
||||
val = src[icolStart+len(key):]
|
||||
icolend := strings.Index(val, ";")
|
||||
if icolend > 0 {
|
||||
val = val[:icolend]
|
||||
}
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
49
pkg/reflection/helpers.go
Normal file
49
pkg/reflection/helpers.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package reflection
|
||||
|
||||
import "reflect"
|
||||
|
||||
func Len(v any) int {
|
||||
val := reflect.ValueOf(v)
|
||||
valKind := val.Kind()
|
||||
|
||||
if valKind == reflect.Ptr {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
switch val.Kind() {
|
||||
case reflect.Slice, reflect.Array, reflect.Map, reflect.String, reflect.Chan:
|
||||
return val.Len()
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractTableNameOnly extracts the table name from a fully qualified table reference.
|
||||
// It removes any schema prefix (e.g., "schema.table" -> "table") and truncates at
|
||||
// the first delimiter (comma, space, tab, or newline). If the input contains multiple
|
||||
// dots, it returns everything after the last dot up to the first delimiter.
|
||||
func ExtractTableNameOnly(fullName string) string {
|
||||
// First, split by dot to remove schema prefix if present
|
||||
lastDotIndex := -1
|
||||
for i, char := range fullName {
|
||||
if char == '.' {
|
||||
lastDotIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// Start from after the last dot (or from beginning if no dot)
|
||||
startIndex := 0
|
||||
if lastDotIndex != -1 {
|
||||
startIndex = lastDotIndex + 1
|
||||
}
|
||||
|
||||
// Now find the end (first delimiter after the table name)
|
||||
for i := startIndex; i < len(fullName); i++ {
|
||||
char := rune(fullName[i])
|
||||
if char == ',' || char == ' ' || char == '\t' || char == '\n' {
|
||||
return fullName[startIndex:i]
|
||||
}
|
||||
}
|
||||
|
||||
return fullName[startIndex:]
|
||||
}
|
||||
883
pkg/reflection/model_utils.go
Normal file
883
pkg/reflection/model_utils.go
Normal file
@@ -0,0 +1,883 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
type PrimaryKeyNameProvider interface {
|
||||
GetIDName() string
|
||||
}
|
||||
|
||||
// GetPrimaryKeyName extracts the primary key column name from a model
|
||||
// It first checks if the model implements PrimaryKeyNameProvider (GetIDName method)
|
||||
// Falls back to reflection to find bun:",pk" tag, then gorm:"primaryKey" tag
|
||||
func GetPrimaryKeyName(model any) string {
|
||||
if reflect.TypeOf(model) == nil {
|
||||
return ""
|
||||
}
|
||||
// If we are given a string model name, look up the model
|
||||
if reflect.TypeOf(model).Kind() == reflect.String {
|
||||
name := model.(string)
|
||||
m, err := modelregistry.GetModelByName(name)
|
||||
if err == nil {
|
||||
model = m
|
||||
}
|
||||
}
|
||||
|
||||
// Check if model implements PrimaryKeyNameProvider
|
||||
if provider, ok := model.(PrimaryKeyNameProvider); ok {
|
||||
return provider.GetIDName()
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkName := getPrimaryKeyFromReflection(model, "bun"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
if pkName := getPrimaryKeyFromReflection(model, "gorm"); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetPrimaryKeyValue extracts the primary key value from a model instance
|
||||
// Returns the value of the primary key field
|
||||
func GetPrimaryKeyValue(model any) any {
|
||||
if model == nil || reflect.TypeOf(model) == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try Bun tag first
|
||||
if pkValue := findPrimaryKeyValue(val, "bun"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Fall back to GORM tag
|
||||
if pkValue := findPrimaryKeyValue(val, "gorm"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
// Last resort: look for field named "ID" or "Id"
|
||||
if pkValue := findFieldByName(val, "id"); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findPrimaryKeyValue recursively searches for a primary key field in the struct
|
||||
func findPrimaryKeyValue(val reflect.Value, ormType string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if pkValue := findPrimaryKeyValue(fieldValue, ormType); pkValue != nil {
|
||||
return pkValue
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for primary key tag
|
||||
switch ormType {
|
||||
case "bun":
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
case "gorm":
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// findFieldByName recursively searches for a field by name in the struct
|
||||
func findFieldByName(val reflect.Value, name string) any {
|
||||
typ := val.Type()
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
fieldValue := val.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous && field.Type.Kind() == reflect.Struct {
|
||||
// Recursively search in embedded struct
|
||||
if result := findFieldByName(fieldValue, name); result != nil {
|
||||
return result
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if field name matches
|
||||
if strings.EqualFold(field.Name, name) && fieldValue.CanInterface() {
|
||||
return fieldValue.Interface()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelColumns extracts all column names from a model using reflection
|
||||
// It checks bun tags first, then gorm tags, then json tags, and finally falls back to lowercase field names
|
||||
// This function recursively processes embedded structs to include their fields
|
||||
func GetModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
collectColumnsFromType(modelType, &columns)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectColumnsFromType recursively collects column names from a struct type and its embedded fields
|
||||
func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectColumnsFromType(fieldType, columns)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name using the same logic as primary key extraction
|
||||
columnName := getColumnNameFromField(field)
|
||||
|
||||
if columnName != "" {
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getColumnNameFromField extracts the column name from a struct field
|
||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||
func getColumnNameFromField(field reflect.StructField) string {
|
||||
// Try bun tag first
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Try gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to json tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract just the field name before any options
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
}
|
||||
|
||||
// Last resort: use field name in lowercase
|
||||
return strings.ToLower(field.Name)
|
||||
}
|
||||
|
||||
// getPrimaryKeyFromReflection uses reflection to find the primary key field
|
||||
// This function recursively searches embedded structs
|
||||
func getPrimaryKeyFromReflection(model any, ormType string) string {
|
||||
val := reflect.ValueOf(model)
|
||||
if val.Kind() == reflect.Pointer {
|
||||
val = val.Elem()
|
||||
}
|
||||
|
||||
if val.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
typ := val.Type()
|
||||
return findPrimaryKeyNameFromType(typ, ormType)
|
||||
}
|
||||
|
||||
// findPrimaryKeyNameFromType recursively searches for the primary key field name in a struct type
|
||||
func findPrimaryKeyNameFromType(typ reflect.Type, ormType string) string {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if pkName := findPrimaryKeyNameFromType(fieldType, ormType); pkName != "" {
|
||||
return pkName
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch ormType {
|
||||
case "gorm":
|
||||
// Check for gorm tag with primaryKey
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if strings.Contains(gormTag, "primaryKey") {
|
||||
// Try to extract column name from gorm tag
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
case "bun":
|
||||
// Check for bun tag with pk flag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if strings.Contains(bunTag, "pk") {
|
||||
// Extract column name from bun tag
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
return colName
|
||||
}
|
||||
// Fall back to json tag
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" {
|
||||
return strings.Split(jsonTag, ",")[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractColumnFromGormTag extracts the column name from a gorm tag
|
||||
// Example: "column:id;primaryKey" -> "id"
|
||||
func ExtractColumnFromGormTag(tag string) string {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if colName, found := strings.CutPrefix(part, "column:"); found {
|
||||
return colName
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ExtractColumnFromBunTag extracts the column name from a bun tag
|
||||
// Example: "id,pk" -> "id"
|
||||
// Example: ",pk" -> "" (will fall back to json tag)
|
||||
func ExtractColumnFromBunTag(tag string) string {
|
||||
parts := strings.Split(tag, ",")
|
||||
if strings.HasPrefix(strings.ToLower(tag), "table:") || strings.HasPrefix(strings.ToLower(tag), "rel:") || strings.HasPrefix(strings.ToLower(tag), "join:") {
|
||||
return ""
|
||||
}
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSQLModelColumns extracts column names that have valid SQL field mappings
|
||||
// This function only returns columns that:
|
||||
// 1. Have bun or gorm tags (not just json tags)
|
||||
// 2. Are not relations (no rel:, join:, foreignKey, references, many2many tags)
|
||||
// 3. Are not scan-only embedded fields
|
||||
func GetSQLModelColumns(model any) []string {
|
||||
var columns []string
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers, slices, and arrays to get to the base struct type
|
||||
for modelType != nil && (modelType.Kind() == reflect.Pointer || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return columns
|
||||
}
|
||||
|
||||
collectSQLColumnsFromType(modelType, &columns, false)
|
||||
|
||||
return columns
|
||||
}
|
||||
|
||||
// collectSQLColumnsFromType recursively collects SQL column names from a struct type
|
||||
// scanOnlyEmbedded indicates if we're inside a scan-only embedded struct
|
||||
func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbedded bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Check if the embedded struct itself is scan-only
|
||||
isScanOnly := scanOnlyEmbedded
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
isScanOnly = true
|
||||
}
|
||||
|
||||
// Recursively process embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
collectSQLColumnsFromType(fieldType, columns, isScanOnly)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip fields in scan-only embedded structs
|
||||
if scanOnlyEmbedded {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get bun and gorm tags
|
||||
bunTag := field.Tag.Get("bun")
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
|
||||
// Skip if neither bun nor gorm tag exists
|
||||
if bunTag == "" && gormTag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if explicitly marked with "-"
|
||||
if bunTag == "-" || gormTag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is scan-only (bun)
|
||||
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if field itself is read-only (gorm)
|
||||
if gormTag != "" && isGormFieldReadOnly(gormTag) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (bun)
|
||||
if bunTag != "" {
|
||||
// Skip if it's a bun relation (rel:, join:, or m2m:)
|
||||
if strings.Contains(bunTag, "rel:") ||
|
||||
strings.Contains(bunTag, "join:") ||
|
||||
strings.Contains(bunTag, "m2m:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Skip relation fields (gorm)
|
||||
if gormTag != "" {
|
||||
// Skip if it has gorm relationship tags
|
||||
if strings.Contains(gormTag, "foreignKey:") ||
|
||||
strings.Contains(gormTag, "references:") ||
|
||||
strings.Contains(gormTag, "many2many:") ||
|
||||
strings.Contains(gormTag, "constraint:") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get column name
|
||||
columnName := ""
|
||||
if bunTag != "" {
|
||||
columnName = ExtractColumnFromBunTag(bunTag)
|
||||
}
|
||||
if columnName == "" && gormTag != "" {
|
||||
columnName = ExtractColumnFromGormTag(gormTag)
|
||||
}
|
||||
|
||||
// Skip if we couldn't extract a column name
|
||||
if columnName == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
*columns = append(*columns, columnName)
|
||||
}
|
||||
}
|
||||
|
||||
// IsColumnWritable checks if a column can be written to in the database
|
||||
// For bun: returns false if the field has "scanonly" tag
|
||||
// For gorm: returns false if the field has "<-:false" or "->" (read-only) tag
|
||||
// This function recursively searches embedded structs
|
||||
func IsColumnWritable(model any, columnName string) bool {
|
||||
modelType := reflect.TypeOf(model)
|
||||
|
||||
// Unwrap pointers to get to the base struct type
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Validate that we have a struct type
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return false
|
||||
}
|
||||
|
||||
found, writable := isColumnWritableInType(modelType, columnName)
|
||||
if found {
|
||||
return writable
|
||||
}
|
||||
|
||||
// Column not found in model, allow it (might be a dynamic column)
|
||||
return true
|
||||
}
|
||||
|
||||
// isColumnWritableInType recursively searches for a column and checks if it's writable
|
||||
// Returns (found, writable) where found indicates if the column was found
|
||||
func isColumnWritableInType(typ reflect.Type, columnName string) (found bool, writable bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Check if this is an embedded struct
|
||||
if field.Anonymous {
|
||||
// Unwrap pointer type if necessary
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Pointer {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
|
||||
// Recursively search in embedded struct
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
if found, writable := isColumnWritableInType(fieldType, columnName); found {
|
||||
return true, writable
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this field matches the column name
|
||||
fieldColumnName := getColumnNameFromField(field)
|
||||
if fieldColumnName != columnName {
|
||||
continue
|
||||
}
|
||||
|
||||
// Found the field, now check if it's writable
|
||||
// Check bun tag for scanonly
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
if isBunFieldScanOnly(bunTag) {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag for write restrictions
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
if isGormFieldReadOnly(gormTag) {
|
||||
return true, false
|
||||
}
|
||||
}
|
||||
|
||||
// Column is writable
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Column not found
|
||||
return false, false
|
||||
}
|
||||
|
||||
// isBunFieldScanOnly checks if a bun tag indicates the field is scan-only
|
||||
// Example: "column_name,scanonly" -> true
|
||||
func isBunFieldScanOnly(tag string) bool {
|
||||
parts := strings.Split(tag, ",")
|
||||
for _, part := range parts {
|
||||
if strings.TrimSpace(part) == "scanonly" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isGormFieldReadOnly checks if a gorm tag indicates the field is read-only
|
||||
// Examples:
|
||||
// - "<-:false" -> true (no writes allowed)
|
||||
// - "->" -> true (read-only, common pattern)
|
||||
// - "column:name;->" -> true
|
||||
// - "<-:create" -> false (writes allowed on create)
|
||||
func isGormFieldReadOnly(tag string) bool {
|
||||
parts := strings.Split(tag, ";")
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
|
||||
// Check for read-only marker
|
||||
if part == "->" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for write restrictions
|
||||
if value, found := strings.CutPrefix(part, "<-:"); found {
|
||||
if value == "false" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractSourceColumn extracts the base column name from PostgreSQL JSON operators
|
||||
// Examples:
|
||||
// - "columna->>'val'" returns "columna"
|
||||
// - "columna->'key'" returns "columna"
|
||||
// - "columna" returns "columna"
|
||||
// - "table.columna->>'val'" returns "table.columna"
|
||||
func ExtractSourceColumn(colName string) string {
|
||||
// Check for PostgreSQL JSON operators: -> and ->>
|
||||
if idx := strings.Index(colName, "->>"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
if idx := strings.Index(colName, "->"); idx != -1 {
|
||||
return strings.TrimSpace(colName[:idx])
|
||||
}
|
||||
return colName
|
||||
}
|
||||
|
||||
// ToSnakeCase converts a string from CamelCase to snake_case
|
||||
func ToSnakeCase(s string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range s {
|
||||
if i > 0 && r >= 'A' && r <= 'Z' {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// GetColumnTypeFromModel uses reflection to determine the Go type of a column in a model
|
||||
func GetColumnTypeFromModel(model interface{}, colName string) reflect.Kind {
|
||||
if model == nil {
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// Extract the source column name (remove JSON operators like ->> or ->)
|
||||
sourceColName := ExtractSourceColumn(colName)
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
// Dereference pointer if needed
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// Find the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" {
|
||||
// Parse JSON tag (format: "name,omitempty")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if parts[0] == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
// Check field name (case-insensitive)
|
||||
if strings.EqualFold(field.Name, sourceColName) {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
|
||||
// Check snake_case conversion
|
||||
snakeCaseName := ToSnakeCase(field.Name)
|
||||
if snakeCaseName == sourceColName {
|
||||
return field.Type.Kind()
|
||||
}
|
||||
}
|
||||
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
// IsNumericType checks if a reflect.Kind is a numeric type
|
||||
func IsNumericType(kind reflect.Kind) bool {
|
||||
return kind == reflect.Int || kind == reflect.Int8 || kind == reflect.Int16 ||
|
||||
kind == reflect.Int32 || kind == reflect.Int64 || kind == reflect.Uint ||
|
||||
kind == reflect.Uint8 || kind == reflect.Uint16 || kind == reflect.Uint32 ||
|
||||
kind == reflect.Uint64 || kind == reflect.Float32 || kind == reflect.Float64
|
||||
}
|
||||
|
||||
// IsStringType checks if a reflect.Kind is a string type
|
||||
func IsStringType(kind reflect.Kind) bool {
|
||||
return kind == reflect.String
|
||||
}
|
||||
|
||||
// IsNumericValue checks if a string value can be parsed as a number
|
||||
func IsNumericValue(value string) bool {
|
||||
value = strings.TrimSpace(value)
|
||||
_, err := strconv.ParseFloat(value, 64)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// ConvertToNumericType converts a string value to the appropriate numeric type
|
||||
func ConvertToNumericType(value string, kind reflect.Kind) (interface{}, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
|
||||
switch kind {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
// Parse as integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Int8:
|
||||
bitSize = 8
|
||||
case reflect.Int16:
|
||||
bitSize = 16
|
||||
case reflect.Int32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
intVal, err := strconv.ParseInt(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
return int(intVal), nil
|
||||
case reflect.Int8:
|
||||
return int8(intVal), nil
|
||||
case reflect.Int16:
|
||||
return int16(intVal), nil
|
||||
case reflect.Int32:
|
||||
return int32(intVal), nil
|
||||
case reflect.Int64:
|
||||
return intVal, nil
|
||||
}
|
||||
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
// Parse as unsigned integer
|
||||
bitSize := 64
|
||||
switch kind {
|
||||
case reflect.Uint8:
|
||||
bitSize = 8
|
||||
case reflect.Uint16:
|
||||
bitSize = 16
|
||||
case reflect.Uint32:
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
uintVal, err := strconv.ParseUint(value, 10, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid unsigned integer value: %w", err)
|
||||
}
|
||||
|
||||
// Return the appropriate type
|
||||
switch kind {
|
||||
case reflect.Uint:
|
||||
return uint(uintVal), nil
|
||||
case reflect.Uint8:
|
||||
return uint8(uintVal), nil
|
||||
case reflect.Uint16:
|
||||
return uint16(uintVal), nil
|
||||
case reflect.Uint32:
|
||||
return uint32(uintVal), nil
|
||||
case reflect.Uint64:
|
||||
return uintVal, nil
|
||||
}
|
||||
|
||||
case reflect.Float32, reflect.Float64:
|
||||
// Parse as float
|
||||
bitSize := 64
|
||||
if kind == reflect.Float32 {
|
||||
bitSize = 32
|
||||
}
|
||||
|
||||
floatVal, err := strconv.ParseFloat(value, bitSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid float value: %w", err)
|
||||
}
|
||||
|
||||
if kind == reflect.Float32 {
|
||||
return float32(floatVal), nil
|
||||
}
|
||||
return floatVal, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported numeric type: %v", kind)
|
||||
}
|
||||
|
||||
// GetRelationModel gets the model type for a relation field
|
||||
// It searches for the field by name in the following order (case-insensitive):
|
||||
// 1. Actual field name
|
||||
// 2. Bun tag name (if exists)
|
||||
// 3. Gorm tag name (if exists)
|
||||
// 4. JSON tag name (if exists)
|
||||
//
|
||||
// Supports recursive field paths using dot notation (e.g., "MAL.MAL.DEF")
|
||||
// For nested fields, it traverses through each level of the struct hierarchy
|
||||
func GetRelationModel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Split the field name by "." to handle nested/recursive relations
|
||||
fieldParts := strings.Split(fieldName, ".")
|
||||
|
||||
// Start with the current model
|
||||
currentModel := model
|
||||
|
||||
// Traverse through each level of the field path
|
||||
for _, part := range fieldParts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
currentModel = getRelationModelSingleLevel(currentModel, part)
|
||||
if currentModel == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return currentModel
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
if model == nil || fieldName == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find the field by checking in priority order (case-insensitive)
|
||||
var field *reflect.StructField
|
||||
normalizedFieldName := strings.ToLower(fieldName)
|
||||
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
f := modelType.Field(i)
|
||||
|
||||
// 1. Check actual field name (case-insensitive)
|
||||
if strings.EqualFold(f.Name, fieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
|
||||
// 2. Check bun tag name
|
||||
bunTag := f.Tag.Get("bun")
|
||||
if bunTag != "" {
|
||||
bunColName := ExtractColumnFromBunTag(bunTag)
|
||||
if bunColName != "" && strings.EqualFold(bunColName, normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check gorm tag name
|
||||
gormTag := f.Tag.Get("gorm")
|
||||
if gormTag != "" {
|
||||
gormColName := ExtractColumnFromGormTag(gormTag)
|
||||
if gormColName != "" && strings.EqualFold(gormColName, normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Check JSON tag name
|
||||
jsonTag := f.Tag.Get("json")
|
||||
if jsonTag != "" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||
if strings.EqualFold(parts[0], normalizedFieldName) {
|
||||
field = &f
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if field == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the target type
|
||||
targetType := field.Type
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if targetType.Kind() == reflect.Slice {
|
||||
targetType = targetType.Elem()
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
if targetType == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if targetType.Kind() != reflect.Struct {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create a zero value of the target type
|
||||
return reflect.New(targetType).Elem().Interface()
|
||||
}
|
||||
616
pkg/reflection/model_utils_test.go
Normal file
616
pkg/reflection/model_utils_test.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for GORM
|
||||
type GormModelWithGetIDName struct {
|
||||
ID int `gorm:"column:rid_test;primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (m GormModelWithGetIDName) GetIDName() string {
|
||||
return "rid_test"
|
||||
}
|
||||
|
||||
type GormModelWithColumnTag struct {
|
||||
ID int `gorm:"column:custom_id;primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type GormModelWithJSONFallback struct {
|
||||
ID int `gorm:"primaryKey" json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// Test models for Bun
|
||||
type BunModelWithGetIDName struct {
|
||||
ID int `bun:"rid_test,pk" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (m BunModelWithGetIDName) GetIDName() string {
|
||||
return "rid_test"
|
||||
}
|
||||
|
||||
type BunModelWithColumnTag struct {
|
||||
ID int `bun:"custom_id,pk" json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type BunModelWithJSONFallback struct {
|
||||
ID int `bun:",pk" json:"user_id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "GORM model with GetIDName method",
|
||||
model: GormModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "GORM model with column tag",
|
||||
model: GormModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "GORM model with JSON fallback",
|
||||
model: GormModelWithJSONFallback{},
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer with GetIDName",
|
||||
model: &GormModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer with column tag",
|
||||
model: &GormModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model with GetIDName method",
|
||||
model: BunModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "Bun model with column tag",
|
||||
model: BunModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model with JSON fallback",
|
||||
model: BunModelWithJSONFallback{},
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer with GetIDName",
|
||||
model: &BunModelWithGetIDName{},
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer with column tag",
|
||||
model: &BunModelWithColumnTag{},
|
||||
expected: "custom_id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyName(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractColumnFromGormTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column tag with primaryKey",
|
||||
tag: "column:rid_test;primaryKey",
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "column tag with spaces",
|
||||
tag: "column:user_id ; primaryKey ; autoIncrement",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
name: "no column tag",
|
||||
tag: "primaryKey;autoIncrement",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractColumnFromGormTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractColumnFromGormTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractColumnFromBunTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "column name with pk flag",
|
||||
tag: "rid_test,pk",
|
||||
expected: "rid_test",
|
||||
},
|
||||
{
|
||||
name: "only pk flag",
|
||||
tag: ",pk",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "column with multiple flags",
|
||||
tag: "user_id,pk,autoincrement",
|
||||
expected: "user_id",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ExtractColumnFromBunTag(tt.tag)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ExtractColumnFromBunTag() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with multiple columns",
|
||||
model: BunModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with multiple columns",
|
||||
model: GormModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Bun model pointer",
|
||||
model: &BunModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model pointer",
|
||||
model: &GormModelWithColumnTag{},
|
||||
expected: []string{"custom_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Bun model with JSON fallback",
|
||||
model: BunModelWithJSONFallback{},
|
||||
expected: []string{"user_id", "name"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with JSON fallback",
|
||||
model: GormModelWithJSONFallback{},
|
||||
expected: []string{"user_id", "name"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetModelColumns() returned %d columns, want %d", len(result), len(tt.expected))
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with embedded structs
|
||||
|
||||
type BaseModel struct {
|
||||
ID int `bun:"rid_base,pk" json:"id"`
|
||||
CreatedAt string `bun:"created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type AdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type ModelWithEmbedded struct {
|
||||
BaseModel
|
||||
Name string `bun:"name" json:"name"`
|
||||
Description string `bun:"description" json:"description"`
|
||||
AdhocBuffer
|
||||
}
|
||||
|
||||
type GormBaseModel struct {
|
||||
ID int `gorm:"column:rid_base;primaryKey" json:"id"`
|
||||
CreatedAt string `gorm:"column:created_at" json:"created_at"`
|
||||
}
|
||||
|
||||
type GormAdhocBuffer struct {
|
||||
CQL1 string `json:"cql1,omitempty" gorm:"column:cql1;->" bun:",scanonly"`
|
||||
CQL2 string `json:"cql2,omitempty" gorm:"column:cql2;->" bun:",scanonly"`
|
||||
RowNumber int64 `json:"_rownumber,omitempty" gorm:"-" bun:",scanonly"`
|
||||
}
|
||||
|
||||
type GormModelWithEmbedded struct {
|
||||
GormBaseModel
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
Description string `gorm:"column:description" json:"description"`
|
||||
GormAdhocBuffer
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyNameWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: "rid_base",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyName(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyName() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyValueWithEmbedded(t *testing.T) {
|
||||
bunModel := ModelWithEmbedded{
|
||||
BaseModel: BaseModel{
|
||||
ID: 123,
|
||||
CreatedAt: "2024-01-01",
|
||||
},
|
||||
Name: "Test",
|
||||
Description: "Test Description",
|
||||
}
|
||||
|
||||
gormModel := GormModelWithEmbedded{
|
||||
GormBaseModel: GormBaseModel{
|
||||
ID: 456,
|
||||
CreatedAt: "2024-01-02",
|
||||
},
|
||||
Name: "GORM Test",
|
||||
Description: "GORM Test Description",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded base",
|
||||
model: bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded base (pointer)",
|
||||
model: &bunModel,
|
||||
expected: 123,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base",
|
||||
model: gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded base (pointer)",
|
||||
model: &gormModel,
|
||||
expected: 456,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPrimaryKeyValue(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPrimaryKeyValue() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelColumnsWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with embedded structs",
|
||||
model: ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "Bun model with embedded structs (pointer)",
|
||||
model: &ModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs",
|
||||
model: GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded structs (pointer)",
|
||||
model: &GormModelWithEmbedded{},
|
||||
expected: []string{"rid_base", "created_at", "name", "description", "cql1", "cql2", "_rownumber"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetModelColumns() returned %d columns, want %d. Got: %v", len(result), len(tt.expected), result)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetModelColumns()[%d] = %v, want %v", i, col, tt.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsColumnWritableWithEmbedded(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
columnName string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Bun model - writable column in main struct",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - writable column in embedded base",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column in embedded adhoc buffer",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Bun model - scanonly column _rownumber",
|
||||
model: ModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in main struct",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "name",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - writable column in embedded base",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "rid_base",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column in embedded adhoc buffer",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "cql1",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "GORM model - readonly column _rownumber",
|
||||
model: GormModelWithEmbedded{},
|
||||
columnName: "_rownumber",
|
||||
expected: false, // bun:",scanonly" marks it as read-only, takes precedence over gorm:"-"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsColumnWritable(tt.model, tt.columnName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsColumnWritable(%s) = %v, want %v", tt.columnName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test models with relations for GetSQLModelColumns
|
||||
type User struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
Email string `bun:"email" json:"email"`
|
||||
ProfileData string `json:"profile_data"` // No bun/gorm tag
|
||||
Posts []Post `bun:"rel:has-many,join:id=user_id" json:"posts"`
|
||||
Profile *Profile `bun:"rel:has-one,join:id=user_id" json:"profile"`
|
||||
RowNumber int64 `bun:",scanonly" json:"_rownumber"`
|
||||
}
|
||||
|
||||
type Post struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Title string `gorm:"column:title" json:"title"`
|
||||
UserID int `gorm:"column:user_id;foreignKey" json:"user_id"`
|
||||
User *User `gorm:"foreignKey:UserID;references:ID" json:"user"`
|
||||
Tags []Tag `gorm:"many2many:post_tags" json:"tags"`
|
||||
Content string `json:"content"` // No bun/gorm tag
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Bio string `bun:"bio" json:"bio"`
|
||||
UserID int `bun:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
type Tag struct {
|
||||
ID int `gorm:"column:id;primaryKey" json:"id"`
|
||||
Name string `gorm:"column:name" json:"name"`
|
||||
}
|
||||
|
||||
// Model with scan-only embedded struct
|
||||
type EntityWithScanOnlyEmbedded struct {
|
||||
ID int `bun:"id,pk" json:"id"`
|
||||
Name string `bun:"name" json:"name"`
|
||||
AdhocBuffer `bun:",scanonly"` // Entire embedded struct is scan-only
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumns(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model any
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Bun model with relations - excludes relations and non-SQL fields",
|
||||
model: User{},
|
||||
// Should include: id, name, email (has bun tags)
|
||||
// Should exclude: profile_data (no bun tag), Posts/Profile (relations), RowNumber (scan-only in embedded would be excluded)
|
||||
expected: []string{"id", "name", "email"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with relations - excludes relations and non-SQL fields",
|
||||
model: Post{},
|
||||
// Should include: id, title, user_id (has gorm tags)
|
||||
// Should exclude: content (no gorm tag), User/Tags (relations)
|
||||
expected: []string{"id", "title", "user_id"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded base and scan-only embedded",
|
||||
model: EntityWithScanOnlyEmbedded{},
|
||||
// Should include: id, name from main struct
|
||||
// Should exclude: all fields from AdhocBuffer (scan-only embedded struct)
|
||||
expected: []string{"id", "name"},
|
||||
},
|
||||
{
|
||||
name: "Model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: ModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from BaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2, _rownumber (from AdhocBuffer - scan-only fields)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "GORM model with embedded - includes SQL fields, excludes scan-only",
|
||||
model: GormModelWithEmbedded{},
|
||||
// Should include: rid_base, created_at (from GormBaseModel), name, description (from main)
|
||||
// Should exclude: cql1, cql2 (scan-only), _rownumber (no gorm column tag, marked as -)
|
||||
expected: []string{"rid_base", "created_at", "name", "description"},
|
||||
},
|
||||
{
|
||||
name: "Simple Profile model",
|
||||
model: Profile{},
|
||||
// Should include all fields with bun tags
|
||||
expected: []string{"id", "bio", "user_id"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetSQLModelColumns(tt.model)
|
||||
if len(result) != len(tt.expected) {
|
||||
t.Errorf("GetSQLModelColumns() returned %d columns, want %d.\nGot: %v\nWant: %v",
|
||||
len(result), len(tt.expected), result, tt.expected)
|
||||
return
|
||||
}
|
||||
for i, col := range result {
|
||||
if col != tt.expected[i] {
|
||||
t.Errorf("GetSQLModelColumns()[%d] = %v, want %v.\nFull result: %v",
|
||||
i, col, tt.expected[i], result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLModelColumnsVsGetModelColumns(t *testing.T) {
|
||||
// Demonstrate the difference between GetModelColumns and GetSQLModelColumns
|
||||
user := User{}
|
||||
|
||||
allColumns := GetModelColumns(user)
|
||||
sqlColumns := GetSQLModelColumns(user)
|
||||
|
||||
t.Logf("GetModelColumns(User): %v", allColumns)
|
||||
t.Logf("GetSQLModelColumns(User): %v", sqlColumns)
|
||||
|
||||
// GetModelColumns should return more columns (includes fields with only json tags)
|
||||
if len(allColumns) <= len(sqlColumns) {
|
||||
t.Errorf("Expected GetModelColumns to return more columns than GetSQLModelColumns")
|
||||
}
|
||||
|
||||
// GetSQLModelColumns should not include 'profile_data' (no bun tag)
|
||||
for _, col := range sqlColumns {
|
||||
if col == "profile_data" {
|
||||
t.Errorf("GetSQLModelColumns should not include 'profile_data' (no bun/gorm tag)")
|
||||
}
|
||||
}
|
||||
|
||||
// GetModelColumns should include 'profile_data' (has json tag)
|
||||
hasProfileData := false
|
||||
for _, col := range allColumns {
|
||||
if col == "profile_data" {
|
||||
hasProfileData = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasProfileData {
|
||||
t.Errorf("GetModelColumns should include 'profile_data' (has json tag)")
|
||||
}
|
||||
}
|
||||
138
pkg/resolvespec/context_test.go
Normal file
138
pkg/resolvespec/context_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestContextOperations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Schema
|
||||
t.Run("WithSchema and GetSchema", func(t *testing.T) {
|
||||
ctx = WithSchema(ctx, "public")
|
||||
schema := GetSchema(ctx)
|
||||
if schema != "public" {
|
||||
t.Errorf("Expected schema 'public', got '%s'", schema)
|
||||
}
|
||||
})
|
||||
|
||||
// Test Entity
|
||||
t.Run("WithEntity and GetEntity", func(t *testing.T) {
|
||||
ctx = WithEntity(ctx, "users")
|
||||
entity := GetEntity(ctx)
|
||||
if entity != "users" {
|
||||
t.Errorf("Expected entity 'users', got '%s'", entity)
|
||||
}
|
||||
})
|
||||
|
||||
// Test TableName
|
||||
t.Run("WithTableName and GetTableName", func(t *testing.T) {
|
||||
ctx = WithTableName(ctx, "public.users")
|
||||
tableName := GetTableName(ctx)
|
||||
if tableName != "public.users" {
|
||||
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
|
||||
}
|
||||
})
|
||||
|
||||
// Test Model
|
||||
t.Run("WithModel and GetModel", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
model := &TestModel{ID: 1, Name: "test"}
|
||||
ctx = WithModel(ctx, model)
|
||||
retrieved := GetModel(ctx)
|
||||
if retrieved == nil {
|
||||
t.Error("Expected model to be retrieved, got nil")
|
||||
}
|
||||
if retrievedModel, ok := retrieved.(*TestModel); ok {
|
||||
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
|
||||
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
|
||||
}
|
||||
} else {
|
||||
t.Error("Retrieved model is not of expected type")
|
||||
}
|
||||
})
|
||||
|
||||
// Test ModelPtr
|
||||
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
}
|
||||
models := []*TestModel{}
|
||||
ctx = WithModelPtr(ctx, &models)
|
||||
retrieved := GetModelPtr(ctx)
|
||||
if retrieved == nil {
|
||||
t.Error("Expected modelPtr to be retrieved, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
// Test WithRequestData
|
||||
t.Run("WithRequestData", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
model := &TestModel{ID: 1, Name: "test"}
|
||||
modelPtr := &[]*TestModel{}
|
||||
|
||||
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr)
|
||||
|
||||
if GetSchema(ctx) != "test_schema" {
|
||||
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
|
||||
}
|
||||
if GetEntity(ctx) != "test_entity" {
|
||||
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
|
||||
}
|
||||
if GetTableName(ctx) != "test_schema.test_entity" {
|
||||
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
|
||||
}
|
||||
if GetModel(ctx) == nil {
|
||||
t.Error("Expected model to be set")
|
||||
}
|
||||
if GetModelPtr(ctx) == nil {
|
||||
t.Error("Expected modelPtr to be set")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmptyContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetSchema with empty context", func(t *testing.T) {
|
||||
schema := GetSchema(ctx)
|
||||
if schema != "" {
|
||||
t.Errorf("Expected empty schema, got '%s'", schema)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetEntity with empty context", func(t *testing.T) {
|
||||
entity := GetEntity(ctx)
|
||||
if entity != "" {
|
||||
t.Errorf("Expected empty entity, got '%s'", entity)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTableName with empty context", func(t *testing.T) {
|
||||
tableName := GetTableName(ctx)
|
||||
if tableName != "" {
|
||||
t.Errorf("Expected empty tableName, got '%s'", tableName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetModel with empty context", func(t *testing.T) {
|
||||
model := GetModel(ctx)
|
||||
if model != nil {
|
||||
t.Errorf("Expected nil model, got %v", model)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetModelPtr with empty context", func(t *testing.T) {
|
||||
modelPtr := GetModelPtr(ctx)
|
||||
if modelPtr != nil {
|
||||
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
|
||||
}
|
||||
})
|
||||
}
|
||||
179
pkg/resolvespec/cursor.go
Normal file
179
pkg/resolvespec/cursor.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// CursorDirection defines pagination direction
|
||||
type CursorDirection int
|
||||
|
||||
const (
|
||||
CursorForward CursorDirection = 1
|
||||
CursorBackward CursorDirection = -1
|
||||
)
|
||||
|
||||
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
|
||||
// It uses the current request's sort and cursor values.
|
||||
//
|
||||
// Parameters:
|
||||
// - tableName: name of the main table (e.g. "posts")
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
) (string, error) {
|
||||
// Remove schema prefix if present
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 1. Determine active cursor
|
||||
// --------------------------------------------------------------------- //
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 2. Extract sort columns
|
||||
// --------------------------------------------------------------------- //
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
// Direction from struct
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 5. Build priority OR-AND chain
|
||||
// --------------------------------------------------------------------- //
|
||||
orSQL := buildPriorityChain(whereClauses)
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 6. Final EXISTS subquery
|
||||
// --------------------------------------------------------------------- //
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: get active cursor (forward or backward)
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction CursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, CursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, CursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: build OR-AND priority chain
|
||||
func buildPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
378
pkg/resolvespec/cursor_test.go
Normal file
378
pkg/resolvespec/cursor_test.go
Normal file
@@ -0,0 +1,378 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "123",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains EXISTS subquery
|
||||
if !strings.Contains(filter, "EXISTS") {
|
||||
t.Errorf("Filter should contain EXISTS subquery, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter references the cursor ID
|
||||
if !strings.Contains(filter, "123") {
|
||||
t.Errorf("Filter should reference cursor ID 123, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains the table name
|
||||
if !strings.Contains(filter, tableName) {
|
||||
t.Errorf("Filter should reference table name %s, got: %s", tableName, filter)
|
||||
}
|
||||
|
||||
// Verify filter contains primary key
|
||||
if !strings.Contains(filter, pkName) {
|
||||
t.Errorf("Filter should reference primary key %s, got: %s", pkName, filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorBackward: "456",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains cursor ID
|
||||
if !strings.Contains(filter, "456") {
|
||||
t.Errorf("Filter should reference cursor ID 456, got: %s", filter)
|
||||
}
|
||||
|
||||
// For backward cursor, sort direction should be reversed
|
||||
// This is handled internally by the GetCursorFilter function
|
||||
t.Logf("Generated backward cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
},
|
||||
// No cursor set
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no cursor provided") {
|
||||
t.Errorf("Expected 'no cursor provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{},
|
||||
CursorForward: "123",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no sort columns") {
|
||||
t.Errorf("Expected 'no sort columns' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "priority", Direction: "DESC"},
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "789",
|
||||
}
|
||||
|
||||
tableName := "tasks"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify filter contains priority column
|
||||
if !strings.Contains(filter, "priority") {
|
||||
t.Errorf("Filter should reference priority column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains created_at column
|
||||
if !strings.Contains(filter, "created_at") {
|
||||
t.Errorf("Filter should reference created_at column, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated multi-column cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "name", Direction: "ASC"},
|
||||
},
|
||||
CursorForward: "100",
|
||||
}
|
||||
|
||||
tableName := "public.users"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
options common.RequestOptions
|
||||
expectedID string
|
||||
expectedDirection CursorDirection
|
||||
}{
|
||||
{
|
||||
name: "Forward cursor only",
|
||||
options: common.RequestOptions{
|
||||
CursorForward: "123",
|
||||
},
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "Backward cursor only",
|
||||
options: common.RequestOptions{
|
||||
CursorBackward: "456",
|
||||
},
|
||||
expectedID: "456",
|
||||
expectedDirection: CursorBackward,
|
||||
},
|
||||
{
|
||||
name: "Both cursors - forward takes precedence",
|
||||
options: common.RequestOptions{
|
||||
CursorForward: "123",
|
||||
CursorBackward: "456",
|
||||
},
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "No cursors",
|
||||
options: common.RequestOptions{},
|
||||
expectedID: "",
|
||||
expectedDirection: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
id, direction := getActiveCursor(tt.options)
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected cursor ID %q, got %q", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if direction != tt.expectedDirection {
|
||||
t.Errorf("Expected direction %d, got %d", tt.expectedDirection, direction)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveColumn(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field string
|
||||
prefix string
|
||||
tableName string
|
||||
modelColumns []string
|
||||
wantCursor string
|
||||
wantTarget string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Simple column",
|
||||
field: "id",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantCursor: "cursor_select.id",
|
||||
wantTarget: "users.id",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Column with case insensitive match",
|
||||
field: "NAME",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantCursor: "cursor_select.NAME",
|
||||
wantTarget: "users.NAME",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid column",
|
||||
field: "invalid_field",
|
||||
prefix: "",
|
||||
tableName: "users",
|
||||
modelColumns: []string{"id", "name", "email"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "JSON field",
|
||||
field: "metadata->>'key'",
|
||||
prefix: "",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "metadata"},
|
||||
wantCursor: "cursor_select.metadata->>'key'",
|
||||
wantTarget: "posts.metadata->>'key'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
|
||||
if target != tt.wantTarget {
|
||||
t.Errorf("Expected target %q, got %q", tt.wantTarget, target)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > tasks.priority",
|
||||
"cursor_select.created_at > tasks.created_at",
|
||||
"cursor_select.id < tasks.id",
|
||||
}
|
||||
|
||||
result := buildPriorityChain(clauses)
|
||||
|
||||
// Should build OR-AND chain for cursor comparison
|
||||
if !strings.Contains(result, "OR") {
|
||||
t.Error("Priority chain should contain OR operators")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "AND") {
|
||||
t.Error("Priority chain should contain AND operators for composite conditions")
|
||||
}
|
||||
|
||||
// First clause should appear standalone
|
||||
if !strings.Contains(result, clauses[0]) {
|
||||
t.Errorf("Priority chain should contain first clause: %s", clauses[0])
|
||||
}
|
||||
|
||||
t.Logf("Built priority chain: %s", result)
|
||||
}
|
||||
|
||||
func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
// Test that cursor filter doesn't allow SQL injection
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
},
|
||||
CursorForward: "123; DROP TABLE users; --",
|
||||
}
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// The cursor ID is inserted directly into the query
|
||||
// This should be sanitized by the sanitizeWhereClause function in the handler
|
||||
// For now, just verify it generates a filter
|
||||
if filter == "" {
|
||||
t.Error("Expected non-empty cursor filter even with special characters")
|
||||
}
|
||||
|
||||
t.Logf("Generated filter with special chars in cursor: %s", filter)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
367
pkg/resolvespec/handler_test.go
Normal file
367
pkg/resolvespec/handler_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestNewHandler(t *testing.T) {
|
||||
// Note: We can't create a real handler without actual DB and registry
|
||||
// But we can test that the constructor doesn't panic with nil values
|
||||
handler := NewHandler(nil, nil)
|
||||
if handler == nil {
|
||||
t.Error("Expected handler to be created, got nil")
|
||||
}
|
||||
|
||||
if handler.hooks == nil {
|
||||
t.Error("Expected hooks registry to be initialized")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerHooks(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
hooks := handler.Hooks()
|
||||
if hooks == nil {
|
||||
t.Error("Expected hooks registry, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetFallbackHandler(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
// We can't directly call the fallback without mocks, but we can verify it's set
|
||||
handler.SetFallbackHandler(func(w common.ResponseWriter, r common.Request, params map[string]string) {
|
||||
// Fallback handler implementation
|
||||
})
|
||||
|
||||
if handler.fallbackHandler == nil {
|
||||
t.Error("Expected fallback handler to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDatabase(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
db := handler.GetDatabase()
|
||||
// Should return nil since we passed nil
|
||||
if db != nil {
|
||||
t.Error("Expected nil database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTableName(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fullTableName string
|
||||
expectedSchema string
|
||||
expectedTable string
|
||||
}{
|
||||
{
|
||||
name: "Table with schema",
|
||||
fullTableName: "public.users",
|
||||
expectedSchema: "public",
|
||||
expectedTable: "users",
|
||||
},
|
||||
{
|
||||
name: "Table without schema",
|
||||
fullTableName: "users",
|
||||
expectedSchema: "",
|
||||
expectedTable: "users",
|
||||
},
|
||||
{
|
||||
name: "Multiple dots (use last)",
|
||||
fullTableName: "db.public.users",
|
||||
expectedSchema: "db.public",
|
||||
expectedTable: "users",
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
fullTableName: "",
|
||||
expectedSchema: "",
|
||||
expectedTable: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
schema, table := handler.parseTableName(tt.fullTableName)
|
||||
if schema != tt.expectedSchema {
|
||||
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
|
||||
}
|
||||
if table != tt.expectedTable {
|
||||
t.Errorf("Expected table '%s', got '%s'", tt.expectedTable, table)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumnType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field reflect.StructField
|
||||
expectedType string
|
||||
}{
|
||||
{
|
||||
name: "String field",
|
||||
field: reflect.StructField{
|
||||
Name: "Name",
|
||||
Type: reflect.TypeOf(""),
|
||||
},
|
||||
expectedType: "string",
|
||||
},
|
||||
{
|
||||
name: "Int field",
|
||||
field: reflect.StructField{
|
||||
Name: "Count",
|
||||
Type: reflect.TypeOf(int(0)),
|
||||
},
|
||||
expectedType: "integer",
|
||||
},
|
||||
{
|
||||
name: "Int32 field",
|
||||
field: reflect.StructField{
|
||||
Name: "ID",
|
||||
Type: reflect.TypeOf(int32(0)),
|
||||
},
|
||||
expectedType: "integer",
|
||||
},
|
||||
{
|
||||
name: "Int64 field",
|
||||
field: reflect.StructField{
|
||||
Name: "BigID",
|
||||
Type: reflect.TypeOf(int64(0)),
|
||||
},
|
||||
expectedType: "bigint",
|
||||
},
|
||||
{
|
||||
name: "Float32 field",
|
||||
field: reflect.StructField{
|
||||
Name: "Price",
|
||||
Type: reflect.TypeOf(float32(0)),
|
||||
},
|
||||
expectedType: "float",
|
||||
},
|
||||
{
|
||||
name: "Float64 field",
|
||||
field: reflect.StructField{
|
||||
Name: "Amount",
|
||||
Type: reflect.TypeOf(float64(0)),
|
||||
},
|
||||
expectedType: "double",
|
||||
},
|
||||
{
|
||||
name: "Bool field",
|
||||
field: reflect.StructField{
|
||||
Name: "Active",
|
||||
Type: reflect.TypeOf(false),
|
||||
},
|
||||
expectedType: "boolean",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
colType := getColumnType(tt.field)
|
||||
if colType != tt.expectedType {
|
||||
t.Errorf("Expected column type '%s', got '%s'", tt.expectedType, colType)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNullable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
field reflect.StructField
|
||||
nullable bool
|
||||
}{
|
||||
{
|
||||
name: "Pointer type is nullable",
|
||||
field: reflect.StructField{
|
||||
Name: "Name",
|
||||
Type: reflect.TypeOf((*string)(nil)),
|
||||
},
|
||||
nullable: true,
|
||||
},
|
||||
{
|
||||
name: "Non-pointer type without explicit 'not null' tag",
|
||||
field: reflect.StructField{
|
||||
Name: "ID",
|
||||
Type: reflect.TypeOf(int(0)),
|
||||
},
|
||||
nullable: true, // isNullable returns true if there's no explicit "not null" tag
|
||||
},
|
||||
{
|
||||
name: "Field with 'not null' tag is not nullable",
|
||||
field: reflect.StructField{
|
||||
Name: "Email",
|
||||
Type: reflect.TypeOf(""),
|
||||
Tag: `gorm:"not null"`,
|
||||
},
|
||||
nullable: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isNullable(tt.field)
|
||||
if result != tt.nullable {
|
||||
t.Errorf("Expected nullable=%v, got %v", tt.nullable, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSnakeCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
input: "UserID",
|
||||
expected: "user_id",
|
||||
},
|
||||
{
|
||||
input: "DepartmentName",
|
||||
expected: "department_name",
|
||||
},
|
||||
{
|
||||
input: "ID",
|
||||
expected: "id",
|
||||
},
|
||||
{
|
||||
input: "HTTPServer",
|
||||
expected: "http_server",
|
||||
},
|
||||
{
|
||||
input: "createdAt",
|
||||
expected: "created_at",
|
||||
},
|
||||
{
|
||||
input: "name",
|
||||
expected: "name",
|
||||
},
|
||||
{
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
input: "A",
|
||||
expected: "a",
|
||||
},
|
||||
{
|
||||
input: "APIKey",
|
||||
expected: "api_key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := toSnakeCase(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("toSnakeCase(%q) = %q, expected %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTagValue(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tag string
|
||||
key string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Extract foreignKey",
|
||||
tag: "foreignKey:UserID;references:ID",
|
||||
key: "foreignKey",
|
||||
expected: "UserID",
|
||||
},
|
||||
{
|
||||
name: "Extract references",
|
||||
tag: "foreignKey:UserID;references:ID",
|
||||
key: "references",
|
||||
expected: "ID",
|
||||
},
|
||||
{
|
||||
name: "Key not found",
|
||||
tag: "foreignKey:UserID;references:ID",
|
||||
key: "notfound",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty tag",
|
||||
tag: "",
|
||||
key: "foreignKey",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single value",
|
||||
tag: "many2many:user_roles",
|
||||
key: "many2many",
|
||||
expected: "user_roles",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.extractTagValue(tt.tag, tt.key)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractTagValue(%q, %q) = %q, expected %q", tt.tag, tt.key, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyFilter(t *testing.T) {
|
||||
// Note: Without a real database, we can't fully test query execution
|
||||
// But we can test that the method exists
|
||||
_ = NewHandler(nil, nil)
|
||||
|
||||
// The applyFilter method exists and can be tested with actual queries
|
||||
// but requires database setup which is beyond unit test scope
|
||||
t.Log("applyFilter method exists and is used in handler operations")
|
||||
}
|
||||
|
||||
func TestShouldUseNestedProcessor(t *testing.T) {
|
||||
handler := NewHandler(nil, nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Has _request field",
|
||||
data: map[string]interface{}{
|
||||
"_request": "nested",
|
||||
"name": "test",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "No special fields",
|
||||
data: map[string]interface{}{
|
||||
"name": "test",
|
||||
"email": "test@example.com",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Note: Without a real model, we can't fully test this
|
||||
// But we can verify the function exists
|
||||
result := handler.shouldUseNestedProcessor(tt.data, nil)
|
||||
// The actual result depends on the model structure
|
||||
_ = result
|
||||
})
|
||||
}
|
||||
}
|
||||
152
pkg/resolvespec/hooks.go
Normal file
152
pkg/resolvespec/hooks.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
// Create operation hooks
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
// Update operation hooks
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
// Delete operation hooks
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
|
||||
// Scan/Execute operation hooks (for query building)
|
||||
BeforeScan HookType = "before_scan"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler // Reference to the handler for accessing database, registry, etc.
|
||||
Schema string
|
||||
Entity string
|
||||
Model interface{}
|
||||
Options common.RequestOptions
|
||||
Writer common.ResponseWriter
|
||||
Request common.Request
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
Result interface{} // For after hooks
|
||||
Error error // For after hooks
|
||||
|
||||
// Query chain - allows hooks to modify the query before execution
|
||||
Query common.SelectQuery
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
// It receives a HookContext and can modify it or return an error
|
||||
// If an error is returned, the operation will be aborted
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
// NewHookRegistry creates a new hook registry
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a new hook for the specified hook type
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered resolvespec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
// RegisterMultiple registers a hook for multiple hook types
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs all hooks for the specified type in order
|
||||
// If any hook returns an error, execution stops and the error is returned
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d resolvespec hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("Resolvespec hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
logger.Warn("Resolvespec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear removes all hooks for the specified type
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
logger.Info("Cleared all resolvespec hooks for %s", hookType)
|
||||
}
|
||||
|
||||
// ClearAll removes all registered hooks
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
logger.Info("Cleared all resolvespec hooks")
|
||||
}
|
||||
|
||||
// Count returns the number of hooks registered for a specific type
|
||||
func (r *HookRegistry) Count(hookType HookType) int {
|
||||
if hooks, exists := r.hooks[hookType]; exists {
|
||||
return len(hooks)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// HasHooks returns true if there are any hooks registered for the specified type
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
return r.Count(hookType) > 0
|
||||
}
|
||||
|
||||
// GetAllHookTypes returns all hook types that have registered hooks
|
||||
func (r *HookRegistry) GetAllHookTypes() []HookType {
|
||||
types := make([]HookType, 0, len(r.hooks))
|
||||
for hookType := range r.hooks {
|
||||
types = append(types, hookType)
|
||||
}
|
||||
return types
|
||||
}
|
||||
400
pkg/resolvespec/hooks_test.go
Normal file
400
pkg/resolvespec/hooks_test.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHookRegistry(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
// Test registering a hook
|
||||
called := false
|
||||
hook := func(ctx *HookContext) error {
|
||||
called = true
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeRead))
|
||||
}
|
||||
|
||||
// Test executing a hook
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeRead, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if !called {
|
||||
t.Error("Hook was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookExecutionOrder(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
order := []int{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
order = append(order, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
order = append(order, 2)
|
||||
return nil
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
order = append(order, 3)
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeCreate, hook1)
|
||||
registry.Register(BeforeCreate, hook2)
|
||||
registry.Register(BeforeCreate, hook3)
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeCreate, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if len(order) != 3 {
|
||||
t.Errorf("Expected 3 hooks to be called, got %d", len(order))
|
||||
}
|
||||
|
||||
if order[0] != 1 || order[1] != 2 || order[2] != 3 {
|
||||
t.Errorf("Hooks executed in wrong order: %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookError(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
executed := []string{}
|
||||
|
||||
hook1 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook1")
|
||||
return nil
|
||||
}
|
||||
|
||||
hook2 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook2")
|
||||
return fmt.Errorf("hook2 error")
|
||||
}
|
||||
|
||||
hook3 := func(ctx *HookContext) error {
|
||||
executed = append(executed, "hook3")
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeUpdate, hook1)
|
||||
registry.Register(BeforeUpdate, hook2)
|
||||
registry.Register(BeforeUpdate, hook3)
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeUpdate, ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error from hook execution")
|
||||
}
|
||||
|
||||
if len(executed) != 2 {
|
||||
t.Errorf("Expected only 2 hooks to be executed, got %d", len(executed))
|
||||
}
|
||||
|
||||
if executed[0] != "hook1" || executed[1] != "hook2" {
|
||||
t.Errorf("Unexpected execution order: %v", executed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookDataModification(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
modifyHook := func(ctx *HookContext) error {
|
||||
if dataMap, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
dataMap["modified"] = true
|
||||
ctx.Data = dataMap
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeCreate, modifyHook)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "test",
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
Data: data,
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeCreate, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
modifiedData := ctx.Data.(map[string]interface{})
|
||||
if !modifiedData["modified"].(bool) {
|
||||
t.Error("Data was not modified by hook")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterMultiple(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
called := 0
|
||||
hook := func(ctx *HookContext) error {
|
||||
called++
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.RegisterMultiple([]HookType{
|
||||
BeforeRead,
|
||||
BeforeCreate,
|
||||
BeforeUpdate,
|
||||
}, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Error("Hook not registered for BeforeRead")
|
||||
}
|
||||
if registry.Count(BeforeCreate) != 1 {
|
||||
t.Error("Hook not registered for BeforeCreate")
|
||||
}
|
||||
if registry.Count(BeforeUpdate) != 1 {
|
||||
t.Error("Hook not registered for BeforeUpdate")
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
registry.Execute(BeforeRead, ctx)
|
||||
registry.Execute(BeforeCreate, ctx)
|
||||
registry.Execute(BeforeUpdate, ctx)
|
||||
|
||||
if called != 3 {
|
||||
t.Errorf("Expected hook to be called 3 times, got %d", called)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
|
||||
if registry.Count(BeforeRead) != 1 {
|
||||
t.Error("Hook not registered")
|
||||
}
|
||||
|
||||
registry.Clear(BeforeRead)
|
||||
|
||||
if registry.Count(BeforeRead) != 0 {
|
||||
t.Error("Hook not cleared")
|
||||
}
|
||||
|
||||
if registry.Count(BeforeCreate) != 1 {
|
||||
t.Error("Wrong hook was cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearAllHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
registry.Register(BeforeUpdate, hook)
|
||||
|
||||
registry.ClearAll()
|
||||
|
||||
if registry.Count(BeforeRead) != 0 || registry.Count(BeforeCreate) != 0 || registry.Count(BeforeUpdate) != 0 {
|
||||
t.Error("Not all hooks were cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
if registry.HasHooks(BeforeRead) {
|
||||
t.Error("Should not have hooks initially")
|
||||
}
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
if !registry.HasHooks(BeforeRead) {
|
||||
t.Error("Should have hooks after registration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllHookTypes(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
registry.Register(BeforeCreate, hook)
|
||||
registry.Register(AfterUpdate, hook)
|
||||
|
||||
types := registry.GetAllHookTypes()
|
||||
|
||||
if len(types) != 3 {
|
||||
t.Errorf("Expected 3 hook types, got %d", len(types))
|
||||
}
|
||||
|
||||
// Verify all expected types are present
|
||||
expectedTypes := map[HookType]bool{
|
||||
BeforeRead: true,
|
||||
BeforeCreate: true,
|
||||
AfterUpdate: true,
|
||||
}
|
||||
|
||||
for _, hookType := range types {
|
||||
if !expectedTypes[hookType] {
|
||||
t.Errorf("Unexpected hook type: %s", hookType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookContextHandler(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
var capturedHandler *Handler
|
||||
|
||||
hook := func(ctx *HookContext) error {
|
||||
if ctx.Handler == nil {
|
||||
return fmt.Errorf("handler is nil in hook context")
|
||||
}
|
||||
capturedHandler = ctx.Handler
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeRead, hook)
|
||||
|
||||
handler := &Handler{
|
||||
hooks: registry,
|
||||
}
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Handler: handler,
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeRead, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Hook execution failed: %v", err)
|
||||
}
|
||||
|
||||
if capturedHandler == nil {
|
||||
t.Error("Handler was not captured from hook context")
|
||||
}
|
||||
|
||||
if capturedHandler != handler {
|
||||
t.Error("Captured handler does not match original handler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookAbort(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
abortHook := func(ctx *HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "Operation aborted by hook"
|
||||
ctx.AbortCode = 403
|
||||
return nil
|
||||
}
|
||||
|
||||
registry.Register(BeforeCreate, abortHook)
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
err := registry.Execute(BeforeCreate, ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected error when hook sets Abort=true")
|
||||
}
|
||||
|
||||
if err.Error() != "operation aborted by hook: Operation aborted by hook" {
|
||||
t.Errorf("Expected abort error message, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHookTypes(t *testing.T) {
|
||||
// Test all hook type constants
|
||||
hookTypes := []HookType{
|
||||
BeforeRead,
|
||||
AfterRead,
|
||||
BeforeCreate,
|
||||
AfterCreate,
|
||||
BeforeUpdate,
|
||||
AfterUpdate,
|
||||
BeforeDelete,
|
||||
AfterDelete,
|
||||
BeforeScan,
|
||||
}
|
||||
|
||||
for _, hookType := range hookTypes {
|
||||
if string(hookType) == "" {
|
||||
t.Errorf("Hook type should not be empty: %v", hookType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWithNoHooks(t *testing.T) {
|
||||
registry := NewHookRegistry()
|
||||
|
||||
ctx := &HookContext{
|
||||
Context: context.Background(),
|
||||
Schema: "test",
|
||||
Entity: "users",
|
||||
}
|
||||
|
||||
// Executing with no registered hooks should not cause an error
|
||||
err := registry.Execute(BeforeRead, ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Execute should not fail with no hooks, got: %v", err)
|
||||
}
|
||||
}
|
||||
508
pkg/resolvespec/integration_test.go
Normal file
508
pkg/resolvespec/integration_test.go
Normal file
@@ -0,0 +1,508 @@
|
||||
// +build integration
|
||||
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// Test models
|
||||
type TestUser struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"not null" json:"name"`
|
||||
Email string `gorm:"uniqueIndex;not null" json:"email"`
|
||||
Age int `json:"age"`
|
||||
Active bool `gorm:"default:true" json:"active"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Posts []TestPost `gorm:"foreignKey:UserID" json:"posts,omitempty"`
|
||||
}
|
||||
|
||||
func (TestUser) TableName() string {
|
||||
return "test_users"
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
UserID uint `gorm:"not null" json:"user_id"`
|
||||
Title string `gorm:"not null" json:"title"`
|
||||
Content string `json:"content"`
|
||||
Published bool `gorm:"default:false" json:"published"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
User *TestUser `gorm:"foreignKey:UserID" json:"user,omitempty"`
|
||||
Comments []TestComment `gorm:"foreignKey:PostID" json:"comments,omitempty"`
|
||||
}
|
||||
|
||||
func (TestPost) TableName() string {
|
||||
return "test_posts"
|
||||
}
|
||||
|
||||
type TestComment struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
PostID uint `gorm:"not null" json:"post_id"`
|
||||
Content string `gorm:"not null" json:"content"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Post *TestPost `gorm:"foreignKey:PostID" json:"post,omitempty"`
|
||||
}
|
||||
|
||||
func (TestComment) TableName() string {
|
||||
return "test_comments"
|
||||
}
|
||||
|
||||
// Test helper functions
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
// Get connection string from environment or use default
|
||||
dsn := os.Getenv("TEST_DATABASE_URL")
|
||||
if dsn == "" {
|
||||
dsn = "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
|
||||
}
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("Skipping integration test: database not available: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
err = db.AutoMigrate(&TestUser{}, &TestPost{}, &TestComment{})
|
||||
if err != nil {
|
||||
t.Skipf("Skipping integration test: failed to migrate database: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func cleanupTestDB(t *testing.T, db *gorm.DB) {
|
||||
// Clean up test data
|
||||
db.Exec("TRUNCATE TABLE test_comments CASCADE")
|
||||
db.Exec("TRUNCATE TABLE test_posts CASCADE")
|
||||
db.Exec("TRUNCATE TABLE test_users CASCADE")
|
||||
}
|
||||
|
||||
func createTestData(t *testing.T, db *gorm.DB) {
|
||||
users := []TestUser{
|
||||
{Name: "John Doe", Email: "john@example.com", Age: 30, Active: true},
|
||||
{Name: "Jane Smith", Email: "jane@example.com", Age: 25, Active: true},
|
||||
{Name: "Bob Johnson", Email: "bob@example.com", Age: 35, Active: false},
|
||||
}
|
||||
|
||||
for i := range users {
|
||||
if err := db.Create(&users[i]).Error; err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
posts := []TestPost{
|
||||
{UserID: users[0].ID, Title: "First Post", Content: "Hello World", Published: true},
|
||||
{UserID: users[0].ID, Title: "Second Post", Content: "More content", Published: true},
|
||||
{UserID: users[1].ID, Title: "Jane's Post", Content: "Jane's content", Published: false},
|
||||
}
|
||||
|
||||
for i := range posts {
|
||||
if err := db.Create(&posts[i]).Error; err != nil {
|
||||
t.Fatalf("Failed to create test post: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
comments := []TestComment{
|
||||
{PostID: posts[0].ID, Content: "Great post!"},
|
||||
{PostID: posts[0].ID, Content: "Thanks for sharing"},
|
||||
{PostID: posts[1].ID, Content: "Interesting"},
|
||||
}
|
||||
|
||||
for i := range comments {
|
||||
if err := db.Create(&comments[i]).Error; err != nil {
|
||||
t.Fatalf("Failed to create test comment: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Integration tests
|
||||
func TestIntegration_CreateOperation(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Create a new user
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "create",
|
||||
"data": map[string]interface{}{
|
||||
"name": "Test User",
|
||||
"email": "test@example.com",
|
||||
"age": 28,
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response common.Response
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Errorf("Expected success=true, got %v. Error: %v", response.Success, response.Error)
|
||||
}
|
||||
|
||||
// Verify user was created
|
||||
var user TestUser
|
||||
if err := db.Where("email = ?", "test@example.com").First(&user).Error; err != nil {
|
||||
t.Errorf("Failed to find created user: %v", err)
|
||||
}
|
||||
|
||||
if user.Name != "Test User" {
|
||||
t.Errorf("Expected name 'Test User', got '%s'", user.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ReadOperation(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
createTestData(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Read all users
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "read",
|
||||
"options": map[string]interface{}{
|
||||
"limit": 10,
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response common.Response
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Errorf("Expected success=true, got %v", response.Success)
|
||||
}
|
||||
|
||||
if response.Metadata == nil {
|
||||
t.Fatal("Expected metadata, got nil")
|
||||
}
|
||||
|
||||
if response.Metadata.Total != 3 {
|
||||
t.Errorf("Expected 3 users, got %d", response.Metadata.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ReadWithFilters(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
createTestData(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Read users with age > 25
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "read",
|
||||
"options": map[string]interface{}{
|
||||
"filters": []map[string]interface{}{
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gt",
|
||||
"value": 25,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var response common.Response
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Errorf("Expected success=true, got %v", response.Success)
|
||||
}
|
||||
|
||||
// Should return 2 users (John: 30, Bob: 35)
|
||||
if response.Metadata.Total != 2 {
|
||||
t.Errorf("Expected 2 filtered users, got %d", response.Metadata.Total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_UpdateOperation(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
createTestData(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Get user ID
|
||||
var user TestUser
|
||||
db.Where("email = ?", "john@example.com").First(&user)
|
||||
|
||||
// Update user
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "update",
|
||||
"data": map[string]interface{}{
|
||||
"id": user.ID,
|
||||
"age": 31,
|
||||
"name": "John Doe Updated",
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify update
|
||||
var updatedUser TestUser
|
||||
db.First(&updatedUser, user.ID)
|
||||
|
||||
if updatedUser.Age != 31 {
|
||||
t.Errorf("Expected age 31, got %d", updatedUser.Age)
|
||||
}
|
||||
if updatedUser.Name != "John Doe Updated" {
|
||||
t.Errorf("Expected name 'John Doe Updated', got '%s'", updatedUser.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_DeleteOperation(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
createTestData(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Get user ID
|
||||
var user TestUser
|
||||
db.Where("email = ?", "bob@example.com").First(&user)
|
||||
|
||||
// Delete user
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "delete",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", fmt.Sprintf("/public/test_users/%d", user.ID), bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
var count int64
|
||||
db.Model(&TestUser{}).Where("id = ?", user.ID).Count(&count)
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("Expected user to be deleted, but found %d records", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_MetadataOperation(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Get metadata
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "meta",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response common.Response
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Errorf("Expected success=true, got %v", response.Success)
|
||||
}
|
||||
|
||||
// Check that metadata includes columns
|
||||
// The response.Data is an interface{}, we need to unmarshal it properly
|
||||
dataBytes, _ := json.Marshal(response.Data)
|
||||
var metadata common.TableMetadata
|
||||
if err := json.Unmarshal(dataBytes, &metadata); err != nil {
|
||||
t.Fatalf("Failed to unmarshal metadata: %v. Raw data: %+v", err, response.Data)
|
||||
}
|
||||
|
||||
if len(metadata.Columns) == 0 {
|
||||
t.Error("Expected metadata to contain columns")
|
||||
}
|
||||
|
||||
// Verify some expected columns
|
||||
hasID := false
|
||||
hasName := false
|
||||
hasEmail := false
|
||||
for _, col := range metadata.Columns {
|
||||
if col.Name == "id" {
|
||||
hasID = true
|
||||
if !col.IsPrimary {
|
||||
t.Error("Expected 'id' column to be primary key")
|
||||
}
|
||||
}
|
||||
if col.Name == "name" {
|
||||
hasName = true
|
||||
}
|
||||
if col.Name == "email" {
|
||||
hasEmail = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasID || !hasName || !hasEmail {
|
||||
t.Error("Expected metadata to contain 'id', 'name', and 'email' columns")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_ReadWithPreload(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
defer cleanupTestDB(t, db)
|
||||
createTestData(t, db)
|
||||
|
||||
handler := NewHandlerWithGORM(db)
|
||||
handler.RegisterModel("public", "test_users", TestUser{})
|
||||
handler.RegisterModel("public", "test_posts", TestPost{})
|
||||
handler.RegisterModel("public", "test_comments", TestComment{})
|
||||
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Read users with posts preloaded
|
||||
requestBody := map[string]interface{}{
|
||||
"operation": "read",
|
||||
"options": map[string]interface{}{
|
||||
"filters": []map[string]interface{}{
|
||||
{
|
||||
"column": "email",
|
||||
"operator": "eq",
|
||||
"value": "john@example.com",
|
||||
},
|
||||
},
|
||||
"preload": []map[string]interface{}{
|
||||
{"relation": "posts"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(requestBody)
|
||||
req := httptest.NewRequest("POST", "/public/test_users", bytes.NewBuffer(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
muxRouter.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("Expected status 200, got %d. Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response common.Response
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse response: %v", err)
|
||||
}
|
||||
|
||||
if !response.Success {
|
||||
t.Errorf("Expected success=true, got %v", response.Success)
|
||||
}
|
||||
|
||||
// Verify posts are preloaded
|
||||
dataBytes, _ := json.Marshal(response.Data)
|
||||
var users []TestUser
|
||||
json.Unmarshal(dataBytes, &users)
|
||||
|
||||
if len(users) == 0 {
|
||||
t.Fatal("Expected at least one user")
|
||||
}
|
||||
|
||||
if len(users[0].Posts) == 0 {
|
||||
t.Error("Expected posts to be preloaded")
|
||||
}
|
||||
}
|
||||
@@ -10,18 +10,18 @@ type GormTableSchemaInterface interface {
|
||||
}
|
||||
|
||||
type GormTableCRUDRequest struct {
|
||||
CRUDRequest *string `json:"crud_request"`
|
||||
Request *string `json:"_request"`
|
||||
}
|
||||
|
||||
func (r *GormTableCRUDRequest) SetRequest(request string) {
|
||||
r.CRUDRequest = &request
|
||||
r.Request = &request
|
||||
}
|
||||
|
||||
func (r GormTableCRUDRequest) GetRequest() string {
|
||||
return *r.CRUDRequest
|
||||
return *r.Request
|
||||
}
|
||||
|
||||
// New interfaces that replace the legacy ones above
|
||||
// These are now defined in database.go:
|
||||
// - TableNameProvider (replaces GormTableNameInterface)
|
||||
// - TableNameProvider (replaces GormTableNameInterface)
|
||||
// - SchemaProvider (replaces GormTableSchemaInterface)
|
||||
|
||||
@@ -2,14 +2,17 @@ package resolvespec
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/common/adapters/router"
|
||||
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/uptrace/bun"
|
||||
"github.com/uptrace/bunrouter"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/router"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// NewHandlerWithGORM creates a new Handler with GORM adapter
|
||||
@@ -36,28 +39,122 @@ func NewStandardBunRouter() *router.StandardBunRouterAdapter {
|
||||
return router.NewStandardBunRouterAdapter()
|
||||
}
|
||||
|
||||
// MiddlewareFunc is a function that wraps an http.Handler with additional functionality
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// SetupMuxRoutes sets up routes for the ResolveSpec API with Mux
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}).Methods("POST")
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
// Example: SetupMuxRoutes(router, handler, func(h http.Handler) http.Handler { return security.NewAuthHandler(securityList, h) })
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
muxRouter.HandleFunc("/{schema}/{entity}/{id}", func(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}).Methods("POST")
|
||||
// Loop through each registered model and create explicit routes
|
||||
for fullName := range allModels {
|
||||
// Parse the full name (e.g., "public.users" or just "users")
|
||||
schema, entity := parseModelName(fullName)
|
||||
|
||||
muxRouter.HandleFunc("/{schema}/{entity}", func(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
// Build the route paths
|
||||
entityPath := buildRoutePath(schema, entity)
|
||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
// Register routes for this entity
|
||||
muxRouter.Handle(entityPath, postEntityHandler).Methods("POST")
|
||||
muxRouter.Handle(entityWithIDPath, postEntityWithIDHandler).Methods("POST")
|
||||
muxRouter.Handle(entityPath, getEntityHandler).Methods("GET")
|
||||
muxRouter.Handle(entityPath, optionsEntityHandler).Methods("OPTIONS")
|
||||
muxRouter.Handle(entityWithIDPath, optionsEntityWithIDHandler).Methods("OPTIONS")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create Mux handler for a specific entity with CORS support
|
||||
func createMuxHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create Mux GET handler for a specific entity with CORS support
|
||||
func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}).Methods("GET")
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to create Mux OPTIONS handler that returns metadata
|
||||
func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMethods []string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// Set CORS headers with the allowed methods for this route
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
corsConfig.AllowedMethods = allowedMethods
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
|
||||
// Return metadata in the OPTIONS response body
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
|
||||
// parseModelName parses a model name like "public.users" into schema and entity
|
||||
// If no schema is present, returns empty string for schema
|
||||
func parseModelName(fullName string) (schema, entity string) {
|
||||
parts := strings.Split(fullName, ".")
|
||||
if len(parts) == 2 {
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
return "", fullName
|
||||
}
|
||||
|
||||
// buildRoutePath builds a route path from schema and entity
|
||||
// If schema is empty, returns just "/entity", otherwise "/{schema}/{entity}"
|
||||
func buildRoutePath(schema, entity string) string {
|
||||
if schema == "" {
|
||||
return "/" + entity
|
||||
}
|
||||
return "/" + schema + "/" + entity
|
||||
}
|
||||
|
||||
// Example usage functions for documentation:
|
||||
@@ -67,12 +164,20 @@ func ExampleWithGORM(db *gorm.DB) {
|
||||
// Create handler using GORM
|
||||
handler := NewHandlerWithGORM(db)
|
||||
|
||||
// Setup router
|
||||
// Setup router without authentication
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler)
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
|
||||
// Register models
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
|
||||
// To add authentication, pass a middleware function:
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
}
|
||||
|
||||
// ExampleWithBun shows how to switch to Bun ORM
|
||||
@@ -87,60 +192,118 @@ func ExampleWithBun(bunDB *bun.DB) {
|
||||
// Create handler
|
||||
handler := NewHandler(dbAdapter, registry)
|
||||
|
||||
// Setup routes
|
||||
// Setup routes without authentication
|
||||
muxRouter := mux.NewRouter()
|
||||
SetupMuxRoutes(muxRouter, handler)
|
||||
SetupMuxRoutes(muxRouter, handler, nil)
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||
func SetupBunRouterRoutes(bunRouter *router.StandardBunRouterAdapter, handler *Handler) {
|
||||
r := bunRouter.GetBunRouter()
|
||||
|
||||
r.Handle("POST", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
"entity": req.Param("entity"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
// Get all registered models from the registry
|
||||
allModels := handler.registry.GetAllModels()
|
||||
|
||||
r.Handle("POST", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
"entity": req.Param("entity"),
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
|
||||
r.Handle("GET", "/:schema/:entity", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
"entity": req.Param("entity"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
// Loop through each registered model and create explicit routes
|
||||
for fullName := range allModels {
|
||||
// Parse the full name (e.g., "public.users" or just "users")
|
||||
schema, entity := parseModelName(fullName)
|
||||
|
||||
r.Handle("GET", "/:schema/:entity/:id", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
params := map[string]string{
|
||||
"schema": req.Param("schema"),
|
||||
"entity": req.Param("entity"),
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
// Build the route paths
|
||||
entityPath := buildRoutePath(schema, entity)
|
||||
entityWithIDPath := entityPath + "/:id"
|
||||
|
||||
// Create closure variables to capture current schema and entity
|
||||
currentSchema := schema
|
||||
currentEntity := entity
|
||||
|
||||
// POST route without ID
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// POST route with ID
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// GET route without ID
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// GET route with ID
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ExampleWithBunRouter shows how to use bunrouter from uptrace
|
||||
|
||||
114
pkg/resolvespec/resolvespec_test.go
Normal file
114
pkg/resolvespec/resolvespec_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fullName string
|
||||
expectedSchema string
|
||||
expectedEntity string
|
||||
}{
|
||||
{
|
||||
name: "Model with schema",
|
||||
fullName: "public.users",
|
||||
expectedSchema: "public",
|
||||
expectedEntity: "users",
|
||||
},
|
||||
{
|
||||
name: "Model without schema",
|
||||
fullName: "users",
|
||||
expectedSchema: "",
|
||||
expectedEntity: "users",
|
||||
},
|
||||
{
|
||||
name: "Model with custom schema",
|
||||
fullName: "myschema.products",
|
||||
expectedSchema: "myschema",
|
||||
expectedEntity: "products",
|
||||
},
|
||||
{
|
||||
name: "Empty string",
|
||||
fullName: "",
|
||||
expectedSchema: "",
|
||||
expectedEntity: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
schema, entity := parseModelName(tt.fullName)
|
||||
if schema != tt.expectedSchema {
|
||||
t.Errorf("Expected schema '%s', got '%s'", tt.expectedSchema, schema)
|
||||
}
|
||||
if entity != tt.expectedEntity {
|
||||
t.Errorf("Expected entity '%s', got '%s'", tt.expectedEntity, entity)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRoutePath(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
schema string
|
||||
entity string
|
||||
expectedPath string
|
||||
}{
|
||||
{
|
||||
name: "With schema",
|
||||
schema: "public",
|
||||
entity: "users",
|
||||
expectedPath: "/public/users",
|
||||
},
|
||||
{
|
||||
name: "Without schema",
|
||||
schema: "",
|
||||
entity: "users",
|
||||
expectedPath: "/users",
|
||||
},
|
||||
{
|
||||
name: "Custom schema",
|
||||
schema: "admin",
|
||||
entity: "logs",
|
||||
expectedPath: "/admin/logs",
|
||||
},
|
||||
{
|
||||
name: "Empty entity with schema",
|
||||
schema: "public",
|
||||
entity: "",
|
||||
expectedPath: "/public/",
|
||||
},
|
||||
{
|
||||
name: "Both empty",
|
||||
schema: "",
|
||||
entity: "",
|
||||
expectedPath: "/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := buildRoutePath(tt.schema, tt.entity)
|
||||
if path != tt.expectedPath {
|
||||
t.Errorf("Expected path '%s', got '%s'", tt.expectedPath, path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStandardMuxRouter(t *testing.T) {
|
||||
router := NewStandardMuxRouter()
|
||||
if router == nil {
|
||||
t.Error("Expected router to be created, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewStandardBunRouter(t *testing.T) {
|
||||
router := NewStandardBunRouter()
|
||||
if router == nil {
|
||||
t.Error("Expected router to be created, got nil")
|
||||
}
|
||||
}
|
||||
85
pkg/resolvespec/security_hooks.go
Normal file
85
pkg/resolvespec/security_hooks.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: BeforeScan - Apply row-level security filters
|
||||
handler.Hooks().Register(BeforeScan, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyRowSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 4 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvespec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts resolvespec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
return s.ctx.Query
|
||||
}
|
||||
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if q, ok := query.(common.SelectQuery); ok {
|
||||
s.ctx.Query = q
|
||||
}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
@@ -13,6 +13,7 @@ const (
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
contextKeyOptions contextKey = "options"
|
||||
)
|
||||
|
||||
// WithSchema adds schema to context
|
||||
@@ -74,12 +75,28 @@ func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
// WithOptions adds request options to context
|
||||
func WithOptions(ctx context.Context, options ExtendedRequestOptions) context.Context {
|
||||
return context.WithValue(ctx, contextKeyOptions, options)
|
||||
}
|
||||
|
||||
// GetOptions retrieves request options from context
|
||||
func GetOptions(ctx context.Context) *ExtendedRequestOptions {
|
||||
if v := ctx.Value(contextKeyOptions); v != nil {
|
||||
if opts, ok := v.(ExtendedRequestOptions); ok {
|
||||
return &opts
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithRequestData adds all request-scoped data to context at once
|
||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
func WithRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}, options ExtendedRequestOptions) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
ctx = WithOptions(ctx, options)
|
||||
return ctx
|
||||
}
|
||||
|
||||
181
pkg/restheadspec/context_test.go
Normal file
181
pkg/restheadspec/context_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestContextOperations(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test Schema
|
||||
t.Run("WithSchema and GetSchema", func(t *testing.T) {
|
||||
ctx = WithSchema(ctx, "public")
|
||||
schema := GetSchema(ctx)
|
||||
if schema != "public" {
|
||||
t.Errorf("Expected schema 'public', got '%s'", schema)
|
||||
}
|
||||
})
|
||||
|
||||
// Test Entity
|
||||
t.Run("WithEntity and GetEntity", func(t *testing.T) {
|
||||
ctx = WithEntity(ctx, "users")
|
||||
entity := GetEntity(ctx)
|
||||
if entity != "users" {
|
||||
t.Errorf("Expected entity 'users', got '%s'", entity)
|
||||
}
|
||||
})
|
||||
|
||||
// Test TableName
|
||||
t.Run("WithTableName and GetTableName", func(t *testing.T) {
|
||||
ctx = WithTableName(ctx, "public.users")
|
||||
tableName := GetTableName(ctx)
|
||||
if tableName != "public.users" {
|
||||
t.Errorf("Expected tableName 'public.users', got '%s'", tableName)
|
||||
}
|
||||
})
|
||||
|
||||
// Test Model
|
||||
t.Run("WithModel and GetModel", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
model := &TestModel{ID: 1, Name: "test"}
|
||||
ctx = WithModel(ctx, model)
|
||||
retrieved := GetModel(ctx)
|
||||
if retrieved == nil {
|
||||
t.Error("Expected model to be retrieved, got nil")
|
||||
}
|
||||
if retrievedModel, ok := retrieved.(*TestModel); ok {
|
||||
if retrievedModel.ID != 1 || retrievedModel.Name != "test" {
|
||||
t.Errorf("Expected model with ID=1 and Name='test', got ID=%d, Name='%s'", retrievedModel.ID, retrievedModel.Name)
|
||||
}
|
||||
} else {
|
||||
t.Error("Retrieved model is not of expected type")
|
||||
}
|
||||
})
|
||||
|
||||
// Test ModelPtr
|
||||
t.Run("WithModelPtr and GetModelPtr", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
}
|
||||
models := []*TestModel{}
|
||||
ctx = WithModelPtr(ctx, &models)
|
||||
retrieved := GetModelPtr(ctx)
|
||||
if retrieved == nil {
|
||||
t.Error("Expected modelPtr to be retrieved, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
// Test Options
|
||||
t.Run("WithOptions and GetOptions", func(t *testing.T) {
|
||||
limit := 10
|
||||
options := ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Limit: &limit,
|
||||
},
|
||||
}
|
||||
ctx = WithOptions(ctx, options)
|
||||
retrieved := GetOptions(ctx)
|
||||
if retrieved == nil {
|
||||
t.Error("Expected options to be retrieved, got nil")
|
||||
return
|
||||
}
|
||||
if retrieved.Limit == nil || *retrieved.Limit != 10 {
|
||||
t.Error("Expected options to be retrieved with limit=10")
|
||||
}
|
||||
})
|
||||
|
||||
// Test WithRequestData
|
||||
t.Run("WithRequestData", func(t *testing.T) {
|
||||
type TestModel struct {
|
||||
ID int
|
||||
Name string
|
||||
}
|
||||
model := &TestModel{ID: 1, Name: "test"}
|
||||
modelPtr := &[]*TestModel{}
|
||||
limit := 20
|
||||
options := ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Limit: &limit,
|
||||
},
|
||||
}
|
||||
|
||||
ctx = WithRequestData(ctx, "test_schema", "test_entity", "test_schema.test_entity", model, modelPtr, options)
|
||||
|
||||
if GetSchema(ctx) != "test_schema" {
|
||||
t.Errorf("Expected schema 'test_schema', got '%s'", GetSchema(ctx))
|
||||
}
|
||||
if GetEntity(ctx) != "test_entity" {
|
||||
t.Errorf("Expected entity 'test_entity', got '%s'", GetEntity(ctx))
|
||||
}
|
||||
if GetTableName(ctx) != "test_schema.test_entity" {
|
||||
t.Errorf("Expected tableName 'test_schema.test_entity', got '%s'", GetTableName(ctx))
|
||||
}
|
||||
if GetModel(ctx) == nil {
|
||||
t.Error("Expected model to be set")
|
||||
}
|
||||
if GetModelPtr(ctx) == nil {
|
||||
t.Error("Expected modelPtr to be set")
|
||||
}
|
||||
opts := GetOptions(ctx)
|
||||
if opts == nil {
|
||||
t.Error("Expected options to be set")
|
||||
return
|
||||
}
|
||||
if opts.Limit == nil || *opts.Limit != 20 {
|
||||
t.Error("Expected options to be set with limit=20")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEmptyContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("GetSchema with empty context", func(t *testing.T) {
|
||||
schema := GetSchema(ctx)
|
||||
if schema != "" {
|
||||
t.Errorf("Expected empty schema, got '%s'", schema)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetEntity with empty context", func(t *testing.T) {
|
||||
entity := GetEntity(ctx)
|
||||
if entity != "" {
|
||||
t.Errorf("Expected empty entity, got '%s'", entity)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetTableName with empty context", func(t *testing.T) {
|
||||
tableName := GetTableName(ctx)
|
||||
if tableName != "" {
|
||||
t.Errorf("Expected empty tableName, got '%s'", tableName)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetModel with empty context", func(t *testing.T) {
|
||||
model := GetModel(ctx)
|
||||
if model != nil {
|
||||
t.Errorf("Expected nil model, got %v", model)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetModelPtr with empty context", func(t *testing.T) {
|
||||
modelPtr := GetModelPtr(ctx)
|
||||
if modelPtr != nil {
|
||||
t.Errorf("Expected nil modelPtr, got %v", modelPtr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetOptions with empty context", func(t *testing.T) {
|
||||
options := GetOptions(ctx)
|
||||
// GetOptions returns nil when context is empty
|
||||
if options != nil {
|
||||
t.Errorf("Expected nil options in empty context, got %v", options)
|
||||
}
|
||||
})
|
||||
}
|
||||
226
pkg/restheadspec/cursor.go
Normal file
226
pkg/restheadspec/cursor.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// CursorDirection defines pagination direction
|
||||
type CursorDirection int
|
||||
|
||||
const (
|
||||
CursorForward CursorDirection = 1
|
||||
CursorBackward CursorDirection = -1
|
||||
)
|
||||
|
||||
// GetCursorFilter generates a SQL `EXISTS` subquery for cursor-based pagination.
|
||||
// It uses the current request's sort, cursor, joins (via Expand), and CQL (via ComputedQL).
|
||||
//
|
||||
// Parameters:
|
||||
// - tableName: name of the main table (e.g. "post")
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses (e.g. "user": "LEFT JOIN user ON ...")
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string, // optional: for validation
|
||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||
) (string, error) {
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
// --------------------------------------------------------------------- //
|
||||
// 1. Determine active cursor
|
||||
// --------------------------------------------------------------------- //
|
||||
cursorID, direction := opts.getActiveCursor()
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 2. Extract sort columns
|
||||
// --------------------------------------------------------------------- //
|
||||
sortItems := opts.getSortColumns()
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "user.name desc nulls last"
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
// Direction from struct or string
|
||||
desc := strings.EqualFold(s.Direction, "desc") ||
|
||||
strings.Contains(strings.ToLower(field), "desc")
|
||||
field = opts.cleanSortField(field)
|
||||
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, isJoin, err := opts.resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 5. Build priority OR-AND chain
|
||||
// --------------------------------------------------------------------- //
|
||||
orSQL := buildPriorityChain(whereClauses)
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 6. Final EXISTS subquery
|
||||
// --------------------------------------------------------------------- //
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: get active cursor (forward or backward)
|
||||
func (opts *ExtendedRequestOptions) getActiveCursor() (id string, direction CursorDirection) {
|
||||
if opts.CursorForward != "" {
|
||||
return opts.CursorForward, CursorForward
|
||||
}
|
||||
if opts.CursorBackward != "" {
|
||||
return opts.CursorBackward, CursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: extract sort columns
|
||||
func (opts *ExtendedRequestOptions) getSortColumns() []common.SortOption {
|
||||
if opts.Sort != nil {
|
||||
return opts.Sort
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper: clean sort field (remove desc, asc, nulls)
|
||||
func (opts *ExtendedRequestOptions) cleanSortField(field string) string {
|
||||
f := strings.ToLower(field)
|
||||
for _, token := range []string{"desc", "asc", "nulls last", "nulls first"} {
|
||||
f = strings.ReplaceAll(f, token, "")
|
||||
}
|
||||
return strings.TrimSpace(f)
|
||||
}
|
||||
|
||||
// Helper: resolve column (main, JSON, CQL, join)
|
||||
func (opts *ExtendedRequestOptions) resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// CQL via ComputedQL
|
||||
if strings.Contains(strings.ToLower(field), "cql") && opts.ComputedQL != nil {
|
||||
if expr, ok := opts.ComputedQL[field]; ok {
|
||||
return "cursor_select." + expr, expr, false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
// Helper: build OR-AND priority chain
|
||||
func buildPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
305
pkg/restheadspec/cursor_test.go
Normal file
305
pkg/restheadspec/cursor_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "123"
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains EXISTS subquery
|
||||
if !strings.Contains(filter, "EXISTS") {
|
||||
t.Errorf("Filter should contain EXISTS subquery, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter references the cursor ID
|
||||
if !strings.Contains(filter, "123") {
|
||||
t.Errorf("Filter should reference cursor ID 123, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains the table name
|
||||
if !strings.Contains(filter, tableName) {
|
||||
t.Errorf("Filter should reference table name %s, got: %s", tableName, filter)
|
||||
}
|
||||
|
||||
// Verify filter contains primary key
|
||||
if !strings.Contains(filter, pkName) {
|
||||
t.Errorf("Filter should reference primary key %s, got: %s", pkName, filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorBackward = "456"
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
if filter == "" {
|
||||
t.Fatal("Expected non-empty cursor filter")
|
||||
}
|
||||
|
||||
// Verify filter contains cursor ID
|
||||
if !strings.Contains(filter, "456") {
|
||||
t.Errorf("Filter should reference cursor ID 456, got: %s", filter)
|
||||
}
|
||||
|
||||
// For backward cursor, sort direction should be reversed
|
||||
// This is handled internally by the GetCursorFilter method
|
||||
t.Logf("Generated backward cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
// No cursor set
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no cursor provided") {
|
||||
t.Errorf("Expected 'no cursor provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "123"
|
||||
|
||||
tableName := "posts"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "no sort columns") {
|
||||
t.Errorf("Expected 'no sort columns' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "priority", Direction: "DESC"},
|
||||
{Column: "created_at", Direction: "DESC"},
|
||||
{Column: "id", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "789"
|
||||
|
||||
tableName := "tasks"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify filter contains priority column
|
||||
if !strings.Contains(filter, "priority") {
|
||||
t.Errorf("Filter should reference priority column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Verify filter contains created_at column
|
||||
if !strings.Contains(filter, "created_at") {
|
||||
t.Errorf("Filter should reference created_at column, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated multi-column cursor filter: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "name", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "100"
|
||||
|
||||
tableName := "public.users"
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cursorForward string
|
||||
cursorBackward string
|
||||
expectedID string
|
||||
expectedDirection CursorDirection
|
||||
}{
|
||||
{
|
||||
name: "Forward cursor only",
|
||||
cursorForward: "123",
|
||||
cursorBackward: "",
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "Backward cursor only",
|
||||
cursorForward: "",
|
||||
cursorBackward: "456",
|
||||
expectedID: "456",
|
||||
expectedDirection: CursorBackward,
|
||||
},
|
||||
{
|
||||
name: "Both cursors - forward takes precedence",
|
||||
cursorForward: "123",
|
||||
cursorBackward: "456",
|
||||
expectedID: "123",
|
||||
expectedDirection: CursorForward,
|
||||
},
|
||||
{
|
||||
name: "No cursors",
|
||||
cursorForward: "",
|
||||
cursorBackward: "",
|
||||
expectedID: "",
|
||||
expectedDirection: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{}
|
||||
opts.CursorForward = tt.cursorForward
|
||||
opts.CursorBackward = tt.cursorBackward
|
||||
|
||||
id, direction := opts.getActiveCursor()
|
||||
|
||||
if id != tt.expectedID {
|
||||
t.Errorf("Expected cursor ID %q, got %q", tt.expectedID, id)
|
||||
}
|
||||
|
||||
if direction != tt.expectedDirection {
|
||||
t.Errorf("Expected direction %d, got %d", tt.expectedDirection, direction)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanSortField(t *testing.T) {
|
||||
opts := &ExtendedRequestOptions{}
|
||||
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"created_at desc", "created_at"},
|
||||
{"name asc", "name"},
|
||||
{"priority desc nulls last", "priority"},
|
||||
{"id asc nulls first", "id"},
|
||||
{"title", "title"},
|
||||
{"updated_at DESC", "updated_at"},
|
||||
{" status asc ", "status"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := opts.cleanSortField(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("cleanSortField(%q) = %q, expected %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > posts.priority",
|
||||
"cursor_select.created_at > posts.created_at",
|
||||
"cursor_select.id < posts.id",
|
||||
}
|
||||
|
||||
result := buildPriorityChain(clauses)
|
||||
|
||||
// Should build OR-AND chain for cursor comparison
|
||||
if !strings.Contains(result, "OR") {
|
||||
t.Error("Priority chain should contain OR operators")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "AND") {
|
||||
t.Error("Priority chain should contain AND operators for composite conditions")
|
||||
}
|
||||
|
||||
// First clause should appear standalone
|
||||
if !strings.Contains(result, clauses[0]) {
|
||||
t.Errorf("Priority chain should contain first clause: %s", clauses[0])
|
||||
}
|
||||
|
||||
t.Logf("Built priority chain: %s", result)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
425
pkg/restheadspec/handler_nested_test.go
Normal file
425
pkg/restheadspec/handler_nested_test.go
Normal file
@@ -0,0 +1,425 @@
|
||||
// +build !integration
|
||||
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test models for nested CRUD operations
|
||||
type TestUser struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
Name string `json:"name"`
|
||||
Posts []TestPost `json:"posts" gorm:"foreignKey:UserID"`
|
||||
}
|
||||
|
||||
type TestPost struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
UserID int64 `json:"user_id"`
|
||||
Title string `json:"title"`
|
||||
Comments []TestComment `json:"comments" gorm:"foreignKey:PostID"`
|
||||
}
|
||||
|
||||
type TestComment struct {
|
||||
ID int64 `json:"id" bun:"id,pk,autoincrement"`
|
||||
PostID int64 `json:"post_id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (TestUser) TableName() string { return "users" }
|
||||
func (TestPost) TableName() string { return "posts" }
|
||||
func (TestComment) TableName() string { return "comments" }
|
||||
|
||||
// Test extractNestedRelations function
|
||||
func TestExtractNestedRelations(t *testing.T) {
|
||||
// Create handler
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
"comments": TestComment{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
model interface{}
|
||||
expectedCleanCount int
|
||||
expectedRelCount int
|
||||
}{
|
||||
{
|
||||
name: "User with posts",
|
||||
data: map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{"title": "Post 1"},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expectedCleanCount: 1, // name
|
||||
expectedRelCount: 1, // posts
|
||||
},
|
||||
{
|
||||
name: "Post with comments",
|
||||
data: map[string]interface{}{
|
||||
"title": "Test Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
{"content": "Comment 2"},
|
||||
},
|
||||
},
|
||||
model: TestPost{},
|
||||
expectedCleanCount: 1, // title
|
||||
expectedRelCount: 1, // comments
|
||||
},
|
||||
{
|
||||
name: "User with nested posts and comments",
|
||||
data: map[string]interface{}{
|
||||
"name": "Jane Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "Post 1",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expectedCleanCount: 1, // name
|
||||
expectedRelCount: 1, // posts (which contains nested comments)
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cleanedData, relations, err := handler.extractNestedRelations(tt.data, tt.model)
|
||||
if err != nil {
|
||||
t.Errorf("extractNestedRelations() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(cleanedData) != tt.expectedCleanCount {
|
||||
t.Errorf("Expected %d cleaned fields, got %d: %+v", tt.expectedCleanCount, len(cleanedData), cleanedData)
|
||||
}
|
||||
|
||||
if len(relations) != tt.expectedRelCount {
|
||||
t.Errorf("Expected %d relation fields, got %d: %+v", tt.expectedRelCount, len(relations), relations)
|
||||
}
|
||||
|
||||
t.Logf("Cleaned data: %+v", cleanedData)
|
||||
t.Logf("Relations: %+v", relations)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test shouldUseNestedProcessor function
|
||||
func TestShouldUseNestedProcessor(t *testing.T) {
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
model interface{}
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "Data with simple nested posts (no further nesting)",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{"title": "Post 1"},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: false, // Simple one-level nesting doesn't require nested processor
|
||||
},
|
||||
{
|
||||
name: "Data with deeply nested relations",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "Post 1",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Comment 1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true, // Multi-level nesting requires nested processor
|
||||
},
|
||||
{
|
||||
name: "Data without nested relations",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "Data with _request field",
|
||||
data: map[string]interface{}{
|
||||
"_request": "insert",
|
||||
"name": "John",
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "Nested data with _request field",
|
||||
data: map[string]interface{}{
|
||||
"name": "John",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"_request": "insert",
|
||||
"title": "Post 1",
|
||||
},
|
||||
},
|
||||
},
|
||||
model: TestUser{},
|
||||
expected: true, // _request at nested level requires nested processor
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.shouldUseNestedProcessor(tt.data, tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shouldUseNestedProcessor() = %v, expected %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test normalizeToSlice function
|
||||
func TestNormalizeToSlice(t *testing.T) {
|
||||
registry := &mockRegistry{}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
expected int // expected slice length
|
||||
}{
|
||||
{
|
||||
name: "Single object",
|
||||
input: map[string]interface{}{"name": "John"},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Slice of objects",
|
||||
input: []map[string]interface{}{
|
||||
{"name": "John"},
|
||||
{"name": "Jane"},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Array of interfaces",
|
||||
input: []interface{}{
|
||||
map[string]interface{}{"name": "John"},
|
||||
map[string]interface{}{"name": "Jane"},
|
||||
map[string]interface{}{"name": "Bob"},
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Nil input",
|
||||
input: nil,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.normalizeToSlice(tt.input)
|
||||
if len(result) != tt.expected {
|
||||
t.Errorf("normalizeToSlice() returned slice of length %d, expected %d", len(result), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test GetRelationshipInfo function
|
||||
func TestGetRelationshipInfo(t *testing.T) {
|
||||
registry := &mockRegistry{}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelType reflect.Type
|
||||
relationName string
|
||||
expectNil bool
|
||||
}{
|
||||
{
|
||||
name: "User posts relation",
|
||||
modelType: reflect.TypeOf(TestUser{}),
|
||||
relationName: "posts",
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Post comments relation",
|
||||
modelType: reflect.TypeOf(TestPost{}),
|
||||
relationName: "comments",
|
||||
expectNil: false,
|
||||
},
|
||||
{
|
||||
name: "Non-existent relation",
|
||||
modelType: reflect.TypeOf(TestUser{}),
|
||||
relationName: "nonexistent",
|
||||
expectNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := handler.GetRelationshipInfo(tt.modelType, tt.relationName)
|
||||
if tt.expectNil && result != nil {
|
||||
t.Errorf("Expected nil, got %+v", result)
|
||||
}
|
||||
if !tt.expectNil && result == nil {
|
||||
t.Errorf("Expected non-nil relationship info")
|
||||
}
|
||||
if result != nil {
|
||||
t.Logf("Relationship info: FieldName=%s, JSONName=%s, RelationType=%s, ForeignKey=%s",
|
||||
result.FieldName, result.JSONName, result.RelationType, result.ForeignKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock registry for testing
|
||||
type mockRegistry struct {
|
||||
models map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockRegistry) Register(name string, model interface{}) {
|
||||
m.RegisterModel(name, model)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) RegisterModel(name string, model interface{}) error {
|
||||
if m.models == nil {
|
||||
m.models = make(map[string]interface{})
|
||||
}
|
||||
m.models[name] = model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModelByEntity(schema, entity string) (interface{}, error) {
|
||||
if model, ok := m.models[entity]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, fmt.Errorf("model not found: %s", entity)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModelByName(name string) (interface{}, error) {
|
||||
if model, ok := m.models[name]; ok {
|
||||
return model, nil
|
||||
}
|
||||
return nil, fmt.Errorf("model not found: %s", name)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetModel(name string) (interface{}, error) {
|
||||
return m.GetModelByName(name)
|
||||
}
|
||||
|
||||
func (m *mockRegistry) HasModel(schema, entity string) bool {
|
||||
_, ok := m.models[entity]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *mockRegistry) ListModels() []string {
|
||||
models := make([]string, 0, len(m.models))
|
||||
for name := range m.models {
|
||||
models = append(models, name)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
func (m *mockRegistry) GetAllModels() map[string]interface{} {
|
||||
return m.models
|
||||
}
|
||||
|
||||
// TestMultiLevelRelationExtraction tests extracting deeply nested relations
|
||||
func TestMultiLevelRelationExtraction(t *testing.T) {
|
||||
registry := &mockRegistry{
|
||||
models: map[string]interface{}{
|
||||
"users": TestUser{},
|
||||
"posts": TestPost{},
|
||||
"comments": TestComment{},
|
||||
},
|
||||
}
|
||||
handler := NewHandler(nil, registry)
|
||||
|
||||
// Test data with 3 levels: User -> Posts -> Comments
|
||||
testData := map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"posts": []map[string]interface{}{
|
||||
{
|
||||
"title": "First Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Great post!"},
|
||||
{"content": "Thanks for sharing!"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"title": "Second Post",
|
||||
"comments": []map[string]interface{}{
|
||||
{"content": "Interesting read"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Extract relations from user
|
||||
cleanedData, relations, err := handler.extractNestedRelations(testData, TestUser{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to extract relations: %v", err)
|
||||
}
|
||||
|
||||
// Verify user data is cleaned
|
||||
if len(cleanedData) != 1 || cleanedData["name"] != "John Doe" {
|
||||
t.Errorf("Expected cleaned data to contain only name, got: %+v", cleanedData)
|
||||
}
|
||||
|
||||
// Verify posts relation was extracted
|
||||
if len(relations) != 1 {
|
||||
t.Errorf("Expected 1 relation (posts), got %d", len(relations))
|
||||
}
|
||||
|
||||
posts, ok := relations["posts"]
|
||||
if !ok {
|
||||
t.Fatal("Expected posts relation to be extracted")
|
||||
}
|
||||
|
||||
// Verify posts is a slice with 2 items
|
||||
postsSlice, ok := posts.([]map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("Expected posts to be []map[string]interface{}, got %T", posts)
|
||||
}
|
||||
|
||||
if len(postsSlice) != 2 {
|
||||
t.Errorf("Expected 2 posts, got %d", len(postsSlice))
|
||||
}
|
||||
|
||||
// Verify first post has comments
|
||||
if _, hasComments := postsSlice[0]["comments"]; !hasComments {
|
||||
t.Error("Expected first post to have comments")
|
||||
}
|
||||
|
||||
t.Logf("Successfully extracted multi-level nested relations")
|
||||
t.Logf("Cleaned data: %+v", cleanedData)
|
||||
t.Logf("Relations: %d posts with nested comments", len(postsSlice))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user