Compare commits

...

152 Commits
v0.0.3 ... main

Author SHA1 Message Date
Hein
932f12ab0a Update handler fixes for Utils bug
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Has been cancelled
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Has been cancelled
Build , Vet Test, and Lint / Lint Code (push) Has been cancelled
Build , Vet Test, and Lint / Build (push) Has been cancelled
Tests / Unit Tests (push) Has been cancelled
Tests / Integration Tests (push) Has been cancelled
2025-12-12 17:01:37 +02:00
Hein
b22792bad6 Optional check for bun 2025-12-12 14:49:52 +02:00
Hein
e8111c01aa Fixed for relation preloading 2025-12-12 11:45:04 +02:00
Hein
5862016031 Added ModelRules
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-12 10:13:11 +02:00
Hein
2f18dde29c Added Tx common.Database to hooks 2025-12-12 09:45:44 +02:00
Hein
31ad217818 Event Broken Concept 2025-12-12 09:23:54 +02:00
Hein
7ef1d6424a Better handling for variables callback
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-11 15:57:01 +02:00
Hein
c50eeac5bf Change the SqlQuery functions parameters on Function Spec 2025-12-11 15:42:00 +02:00
Hein
6d88f2668a Updated login interface with meta 2025-12-11 14:05:27 +02:00
Hein
8a9423df6d Fixed DatabaseAuthenticator JSON value. Added make tag 2025-12-11 13:59:41 +02:00
Hein
4cc943b9d3 Added row PgSQLAdapter
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
2025-12-10 15:28:09 +02:00
Hein
68dee78a34 Fixed filterExtendedOptions 2025-12-10 12:25:23 +02:00
Hein
efb9e5d9d5 Removed the buggy filter expand columns 2025-12-10 12:15:18 +02:00
Hein
490ae37c6d Fixed bugs in extractTableAndColumn 2025-12-10 11:48:03 +02:00
Hein
99307e31e6 More debugging on bun for scan issues 2025-12-10 11:16:25 +02:00
Hein
e3f7869c6d Bun scan debugging 2025-12-10 11:07:18 +02:00
Hein
c696d502c5 extractTableAndColumn 2025-12-10 10:10:55 +02:00
Hein
4ed1fba6ad Fixed extractTableAndColumn 2025-12-10 10:10:43 +02:00
Hein
1d0407a16d Fixed linting
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-10 10:00:01 +02:00
Hein
99001c749d Better sql where validation 2025-12-10 09:52:13 +02:00
Hein
1f7a57f8e3 Tracking provider 2025-12-10 09:31:55 +02:00
Hein
a95c28a0bf Multi Token warning and handling 2025-12-10 08:44:37 +02:00
Hein
e1abd5ebc1 Enhanced the SanitizeWhereClause function 2025-12-10 08:36:24 +02:00
Hein
ca4e53969b Better tests
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 15:32:16 +02:00
Hein
db2b7e878e Better handling of preloads 2025-12-09 15:12:17 +02:00
Hein
9572bfc7b8 Fix qualified column reference (like APIL.rid_hub) in a preload: 2025-12-09 14:46:33 +02:00
Hein
f0962ea1ec Added EnableQueryDebug log 2025-12-09 14:37:09 +02:00
Hein
8fcb065b42 Better Query Debugging 2025-12-09 14:31:26 +02:00
Hein
dc3b621380 Fixed test for session id changes 2025-12-09 14:07:00 +02:00
Hein
a4dd2a7086 exposed types FromString 2025-12-09 14:03:55 +02:00
Hein
3ec2e5f15a Proper handling of fromString in the types 2025-12-09 13:55:51 +02:00
Hein
c52afe2825 Updated sql types 2025-12-09 13:14:22 +02:00
Hein
76e98d02c3 Added modelregistry.GetDefaultRegistry 2025-12-09 12:12:10 +02:00
Hein
23e2db1496 Fixed linting 2025-12-09 12:02:44 +02:00
Hein
d188f49126 Added openapi spec 2025-12-09 12:01:21 +02:00
Hein
0f05202438 Database Authenticator with cache 2025-12-09 11:32:44 +02:00
Hein
b2115038f2 Fixed providers 2025-12-09 11:18:11 +02:00
Hein
229ee4fb28 Fixed DatabaseAuthenticator sq select 2025-12-09 11:05:48 +02:00
Hein
2cf760b979 Added a few auth shortcuts 2025-12-09 10:31:08 +02:00
Hein
0a9c107095 Fixed sqlquery bug in funcspec 2025-12-09 10:19:03 +02:00
Hein
4e2fe33b77 Fixed session_rid in funcspec 2025-12-09 10:04:39 +02:00
Hein
1baa0af0ac Config Package
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-09 09:19:56 +02:00
Hein
659b2925e4 Cursor pagnation for resolvespec 2025-12-09 08:51:15 +02:00
Hein
baca70cafc Split coverage reports
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:20:40 +02:00
Hein
ed57978620 go-version 1.24 2025-12-08 17:14:04 +02:00
Hein
97b39de88a Broken linting 2025-12-08 17:12:44 +02:00
Hein
bf955b7971 Updated version
Some checks are pending
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Waiting to run
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Waiting to run
Build , Vet Test, and Lint / Lint Code (push) Waiting to run
Build , Vet Test, and Lint / Build (push) Waiting to run
Tests / Unit Tests (push) Waiting to run
Tests / Integration Tests (push) Waiting to run
2025-12-08 17:08:23 +02:00
Hein
545856f8a0 Fixed linting issues 2025-12-08 17:07:13 +02:00
Hein
8d123e47bd Updated deps on workflow 2025-12-08 16:59:49 +02:00
Hein
c9eaf84125 A lot more tests 2025-12-08 16:56:48 +02:00
Hein
aeae9d7e0c Added blacklist middleware 2025-12-08 09:26:36 +02:00
Hein
2a84652dba Middleware enhancements 2025-12-08 08:47:13 +02:00
Hein
b741958895 Code sanity fixes, added middlewares
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-08 08:28:43 +02:00
Hein
2442589982 Better headers
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-03 14:42:38 +02:00
Hein
7c1bae60c9 Added meta handlers 2025-12-03 13:52:06 +02:00
Hein
06b2404c0c Remove blank array if no args 2025-12-03 12:25:51 +02:00
Hein
32007480c6 Handle cql columns as text by default 2025-12-03 12:18:33 +02:00
Hein
ab1ce869b6 Handling JSON responses in funcspec 2025-12-03 12:10:13 +02:00
Hein
ff72e04428 Added meta operation. 2025-12-03 11:59:58 +02:00
Hein
e35f8a4f14 Fix session id that is an integer. 2025-12-03 11:49:19 +02:00
Hein
5ff9a8a24e Fixed blank params on funcspec 2025-12-03 11:42:32 +02:00
Hein
81b87af6e4 Updated doc 2025-12-03 11:30:59 +02:00
Hein
f3ba314640 Refectored the mux routers. 2025-12-03 10:42:26 +02:00
Hein
93df33e274 UnderlyingRequest and UnderlyingResponseWriter
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-12-02 17:40:44 +02:00
Hein
abd045493a mux UnderlyingRequest 2025-12-02 17:34:18 +02:00
Hein
a61556d857 Added FallbackHandler 2025-12-02 17:16:34 +02:00
Hein
eaf1133575 Fixed security rules not loading 2025-12-02 16:55:12 +02:00
Hein
8172c0495d More generic security solution. 2025-12-02 16:35:08 +02:00
Hein
7a3c368121 Pass through to default handler 2025-12-02 16:09:36 +02:00
Hein
9c5c7689e9 More common handler interface 2025-12-02 15:45:24 +02:00
Hein
08050c960d Optional Authentication 2025-12-02 14:14:38 +02:00
Hein
78029fb34f Fixed formatting issues
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-12-01 14:56:30 +02:00
Hein
1643a5e920 Added cache, funcspec and implemented total cache 2025-12-01 14:40:54 +02:00
Hein
6bbe0ec8b0 Added function api prototype
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-24 17:00:15 +02:00
Hein
e32ec9e17e Updated the security package 2025-11-24 17:00:05 +02:00
Hein
26c175e65e Added make release to vscode tasks
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-24 10:15:23 +02:00
Hein
aa99e8e4bc Added WrapHTTPRequest 2025-11-24 10:13:48 +02:00
Hein
163593901f Huge preload chains causing errors, workaround to do seperate selects.
Some checks failed
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
Tests / Build (push) Has been cancelled
2025-11-21 17:09:11 +02:00
Hein
1261960e97 Ability to handle multiple x-custom- headers
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-21 12:15:07 +02:00
Hein
76bbf33db2 Fixed SingleRecordAsObject true when handleRead with no id 2025-11-21 11:49:08 +02:00
Hein
02c9b96b0c Better SanitizeWhereClause 2025-11-21 11:42:01 +02:00
Hein
9a3564f05f SanitizeWhereClause with tablename on handlers. 2025-11-21 11:00:44 +02:00
Hein
a931b8cdd2 Better preloads 2025-11-21 10:41:58 +02:00
Hein
7e76977dcc Lots of refactoring, Fixes to preloads 2025-11-21 10:17:20 +02:00
Hein
7853a3f56a cql_columns parsing and recursive preloading. Also added legacy header support for limt(s,e) ,sort(x,y,-z) 2025-11-21 09:15:40 +02:00
Hein
c2e0c36c79 Restheadspec now takes parameters from query parameters and headers. Allows for backward compatibility with our old dojo clients 2025-11-21 08:56:58 +02:00
Hein
59bd709460 More reflection function to handle sql columns and get default sqlcolumn lists. 2025-11-21 08:35:46 +02:00
Hein
05962035b6 when you specify computed columns without explicitly listing base columns, you'll get all base model column
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 17:34:46 +02:00
Hein
1cd04b7083 Better where clause handling for preloads 2025-11-20 17:02:27 +02:00
Hein
0d4909054c Better handling of preload where conditions and a few panic changes 2025-11-20 16:50:26 +02:00
Hein
745564f2e7 More Panic Recovery for reflection on orm 2025-11-20 15:20:21 +02:00
Hein
311e50bfdd Better relation lookup
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-20 14:30:59 +02:00
Hein
c95bc9e633 Added x-files feature 2025-11-20 12:47:36 +02:00
Hein
07b09e2025 handle JSON sql columns 2025-11-20 12:04:19 +02:00
Hein
3d5334002d Fixes on Table Name on insert 2025-11-20 11:49:07 +02:00
Hein
640582d508 Better types 2025-11-20 11:40:16 +02:00
Hein
b0b3ae662b Common Sql Types 2025-11-20 11:18:49 +02:00
Hein
c9b9f75b06 Fixed go mod version issues 2025-11-20 10:34:27 +02:00
Hein
af3260864d INSERT statements were failing with duplicate key errors because the SQL being generated 2025-11-20 10:31:25 +02:00
Hein
ca6d2deff6 Fixed insert statement bug 2025-11-20 10:11:26 +02:00
Hein
1481443516 Fixed double .Model and .Table 2025-11-20 10:02:36 +02:00
Hein
cb54ec5e27 Better responses for updates and inserts 2025-11-20 09:57:24 +02:00
Hein
7d6a9025f5 Fixed hardcoded id 2025-11-20 09:40:11 +02:00
Hein
35089f511f correctly handle structs with embedded fields 2025-11-20 09:28:37 +02:00
Hein
66b6a0d835 Better registry handling
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 18:29:24 +02:00
Hein
456c165814 Fixed models being icorrectly set and added SetDefaultRegistry 2025-11-19 18:22:56 +02:00
Hein
850d7b546c Added modelregistry.AddRegistry 2025-11-19 18:18:18 +02:00
Hein
a44ef90d7c Fixes on getRelationshipInfo, ShouldUseNestedProcessor 2025-11-19 18:03:25 +02:00
Hein
8b7db5b31a reflection-based column validation for UpdateQuery 2025-11-19 17:41:15 +02:00
Hein
14daea3b05 Fixes for CUD operations 2025-11-19 15:08:04 +02:00
Hein
35f23b6d9e Recursive crud fix 2025-11-19 14:32:20 +02:00
Hein
53a4e67f70 Specifically call update if a ID was given. 2025-11-19 14:24:39 +02:00
Hein
1289c3af88 Fixed handling post routes as well for the restheadspec
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-19 14:04:56 +02:00
Hein
cdfb7a67fd Added Single Record as Object feature 2025-11-19 13:58:52 +02:00
Hein
7f5b851669 Empty sort appended bug fix
Some checks failed
Tests / Build (push) Has been cancelled
Tests / Run Tests (1.23.x) (push) Has been cancelled
Tests / Run Tests (1.24.x) (push) Has been cancelled
Tests / Lint Code (push) Has been cancelled
2025-11-11 17:16:59 +02:00
Hein
f0e26b1c0d Fixed and refactored reflection.Len 2025-11-11 17:07:44 +02:00
Hein
1db1b924ef Proper handling of x-preload-col-where 2025-11-11 16:53:02 +02:00
Hein
d9cf23b1dc Fixed column expression bug 2025-11-11 16:39:06 +02:00
Hein
94f013c872 Preload fixes 2025-11-11 15:54:43 +02:00
Hein
c52fcff61d Preload fixes 2025-11-11 15:34:24 +02:00
Hein
ce106fa940 Updated documentation 2025-11-11 14:57:01 +02:00
Hein
37b4b75175 Fixed preload and id fields with GetPrimaryKeyName 2025-11-11 14:32:41 +02:00
Hein
0cef0f75d3 Fixed computed columns 2025-11-11 12:28:53 +02:00
Hein
006dc4a2b2 Using scan model method for better relation handling. e.g bun When querying has-many or many-to-many relationships, you should use Model instead of the dest parameter in Scan 2025-11-11 11:58:41 +02:00
Hein
ecd7b31910 Fixed linting issues 2025-11-11 11:32:30 +02:00
Hein
7b8216b71c More fixes for _request 2025-11-11 11:16:07 +02:00
Hein
682716dd31 Linting fixes 2025-11-11 11:03:02 +02:00
Hein
412bbab560 Added testing for CRUD
Some checks are pending
Tests / Run Tests (1.23.x) (push) Waiting to run
Tests / Run Tests (1.24.x) (push) Waiting to run
Tests / Lint Code (push) Waiting to run
Tests / Build (push) Waiting to run
2025-11-11 10:46:43 +02:00
Hein
dc3254522c Added recursive crud handler. 2025-11-11 10:21:20 +02:00
Hein
2818e7e9cd Remove so debug logs 2025-11-10 17:15:55 +02:00
Hein
e39012ddbd Updates to make release 2025-11-10 17:06:47 +02:00
Hein
ceaa251301 Updated logging, added getRowNumber and a few more 2025-11-10 17:02:37 +02:00
Hein
faafe5abea Added content-range headers 2025-11-10 12:25:09 +02:00
Hein
3eb17666bf Migration to Bitech 2025-11-10 11:43:15 +02:00
Hein
c8704c07dd Added cursor filters and hooks 2025-11-10 10:22:55 +02:00
Hein
fc82a9bc50 todo 2025-11-07 16:30:02 +02:00
Hein
c26ea3cd61 todo 2025-11-07 16:12:09 +02:00
Hein
a5d97cc07b Fixed the filters 2025-11-07 15:58:24 +02:00
Hein
0899ba5029 Pointer Fixes 2025-11-07 14:22:58 +02:00
Hein
c84dd7dc91 Lets try the model approach again 2025-11-07 14:18:15 +02:00
Hein
f1c6b36374 Bun Adaptor updates 2025-11-07 14:03:40 +02:00
Hein
abee5c942f Count Fixes 2025-11-07 13:54:24 +02:00
Hein
2e9a0bd51a Better model pointers 2025-11-07 13:45:08 +02:00
Hein
f518a3c73c Some validation and header decoding 2025-11-07 13:31:48 +02:00
Hein
07c239aaa1 Make sure to enable Clean JSON when select fields given 2025-11-07 11:00:56 +02:00
Hein
1adca4c49b Content Types and Respose fixes for restheadpsec 2025-11-07 10:55:42 +02:00
Hein
eefed23766 COUNT queries were generating incorrect SQL with the table appearing twice 2025-11-07 10:37:53 +02:00
Hein
3b2d05465e Fixed tablename and schema lookups 2025-11-07 10:28:14 +02:00
Hein
e88018543e Reflect safty 2025-11-07 09:47:12 +02:00
Hein
e7e5754a47 Added panic catches 2025-11-07 09:32:37 +02:00
Hein
c88bff1883 Better handling with context 2025-11-07 09:13:06 +02:00
Hein
d122c7af42 Updated how model registry works 2025-11-07 08:26:50 +02:00
185 changed files with 54926 additions and 1194 deletions

1
.claude/readme Normal file
View File

@ -0,0 +1 @@
We use claude for testing and document generation.

52
.env.example Normal file
View File

@ -0,0 +1,52 @@
# ResolveSpec Environment Variables Example
# Environment variables override config file settings
# All variables are prefixed with RESOLVESPEC_
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
# Server Configuration
RESOLVESPEC_SERVER_ADDR=:8080
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
# Tracing Configuration
RESOLVESPEC_TRACING_ENABLED=false
RESOLVESPEC_TRACING_SERVICE_NAME=resolvespec
RESOLVESPEC_TRACING_SERVICE_VERSION=1.0.0
RESOLVESPEC_TRACING_ENDPOINT=http://localhost:4318/v1/traces
# Cache Configuration
RESOLVESPEC_CACHE_PROVIDER=memory
# Redis Cache (when provider=redis)
RESOLVESPEC_CACHE_REDIS_HOST=localhost
RESOLVESPEC_CACHE_REDIS_PORT=6379
RESOLVESPEC_CACHE_REDIS_PASSWORD=
RESOLVESPEC_CACHE_REDIS_DB=0
# Memcache (when provider=memcache)
# Note: For arrays, separate values with commas
RESOLVESPEC_CACHE_MEMCACHE_SERVERS=localhost:11211
RESOLVESPEC_CACHE_MEMCACHE_MAX_IDLE_CONNS=10
RESOLVESPEC_CACHE_MEMCACHE_TIMEOUT=100ms
# Logger Configuration
RESOLVESPEC_LOGGER_DEV=false
RESOLVESPEC_LOGGER_PATH=
# Middleware Configuration
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_RPS=100.0
RESOLVESPEC_MIDDLEWARE_RATE_LIMIT_BURST=200
RESOLVESPEC_MIDDLEWARE_MAX_REQUEST_SIZE=10485760
# CORS Configuration
# Note: For arrays in env vars, separate with commas
RESOLVESPEC_CORS_ALLOWED_ORIGINS=*
RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
RESOLVESPEC_CORS_MAX_AGE=3600
# Database Configuration
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable

84
.github/workflows/maint.yml vendored Normal file
View File

@ -0,0 +1,84 @@
name: Build , Vet Test, and Lint
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
workflow_dispatch:
jobs:
test:
name: Run Vet Tests
runs-on: ubuntu-latest
strategy:
matrix:
go-version: ["1.23.x", "1.24.x"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
cache: true
- name: Display Go version
run: go version
- name: Download dependencies
run: go mod download
- name: Verify dependencies
run: go mod verify
- name: Run go vet
run: go vet ./...
lint:
name: Lint Code
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: true
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v9
with:
version: latest
args: --timeout=5m
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: true
- name: Build
run: go build -v ./...
- name: Check for uncommitted changes
run: |
if [[ -n $(git status -s) ]]; then
echo "Error: Uncommitted changes found after build"
git status -s
exit 1
fi

82
.github/workflows/make_tag.yml vendored Normal file
View File

@ -0,0 +1,82 @@
# This workflow will build a golang project
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
name: Create Go Release (Tag Versioning)
on:
workflow_dispatch:
inputs:
semver:
description: "New Version"
required: true
default: "patch"
type: choice
options:
- patch
- minor
- major
jobs:
tag_and_commit:
name: "Tag and Commit ${{ github.event.inputs.semver }}"
runs-on: linux
permissions:
contents: write # 'write' access to repository contents
pull-requests: write # 'write' access to pull requests
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Set up Git
run: |
git config --global user.name "Hein"
git config --global user.email "hein.puth@gmail.com"
- name: Fetch latest tag
id: latest_tag
run: |
git fetch --tags
latest_tag=$(git describe --tags `git rev-list --tags --max-count=1`)
echo "::set-output name=tag::$latest_tag"
- name: Determine new tag version
id: new_tag
run: |
current_tag=${{ steps.latest_tag.outputs.tag }}
version=$(echo $current_tag | cut -c 2-) # remove the leading 'v'
IFS='.' read -r -a version_parts <<< "$version"
major=${version_parts[0]}
minor=${version_parts[1]}
patch=${version_parts[2]}
case "${{ github.event.inputs.semver }}" in
"patch")
((patch++))
;;
"minor")
((minor++))
patch=0
;;
"release")
((major++))
minor=0
patch=0
;;
*)
echo "Invalid semver input"
exit 1
;;
esac
new_tag="v$major.$minor.$patch"
echo "::set-output name=tag::$new_tag"
- name: Create tag
run: |
git tag -a ${{ steps.new_tag.outputs.tag }} -m "Tagging ${{ steps.new_tag.outputs.tag }} for release"
- name: Push changes
uses: ad-m/github-push-action@master
with:
github_token: ${{ secrets.BITECH_GITHUB_TOKEN }}
force: true
tags: true

81
.github/workflows/tests.yml vendored Normal file
View File

@ -0,0 +1,81 @@
name: Tests
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
jobs:
unit-tests:
name: Unit Tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Run unit tests
run: go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
- name: Generate coverage report
run: |
go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
go tool cover -html=coverage.out -o coverage.html
- name: Upload coverage
uses: actions/upload-artifact@v5
with:
name: coverage-report
path: coverage.html
integration-tests:
name: Integration Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15-alpine
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 5432:5432
steps:
- uses: actions/checkout@v6
- name: Set up Go
uses: actions/setup-go@v6
with:
go-version: "1.24"
- name: Create test databases
env:
PGPASSWORD: postgres
run: |
psql -h localhost -U postgres -c "CREATE DATABASE resolvespec_test;"
psql -h localhost -U postgres -c "CREATE DATABASE restheadspec_test;"
- name: Run resolvespec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/resolvespec -v -coverprofile=coverage-resolvespec-integration.out
- name: Run restheadspec integration tests
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=restheadspec_test port=5432 sslmode=disable"
run: go test -tags=integration ./pkg/restheadspec -v -coverprofile=coverage-restheadspec-integration.out
- name: Generate integration coverage
env:
TEST_DATABASE_URL: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
run: |
go tool cover -html=coverage-resolvespec-integration.out -o coverage-resolvespec-integration.html
go tool cover -html=coverage-restheadspec-integration.out -o coverage-restheadspec-integration.html
- name: Upload resolvespec integration coverage
uses: actions/upload-artifact@v5
with:
name: resolvespec-integration-coverage-report
path: coverage-resolvespec-integration.html
- name: Upload restheadspec integration coverage
uses: actions/upload-artifact@v5
with:
name: integration-coverage-restheadspec-report
path: coverage-restheadspec-integration

1
.gitignore vendored
View File

@ -24,3 +24,4 @@ go.work.sum
# env file # env file
.env .env
bin/ bin/
test.db

110
.golangci.bck.yml Normal file
View File

@ -0,0 +1,110 @@
run:
timeout: 5m
tests: true
skip-dirs:
- vendor
- .github
linters:
enable:
- errcheck
- gosimple
- govet
- ineffassign
- staticcheck
- unused
- gofmt
- goimports
- misspell
- gocritic
- revive
- stylecheck
disable:
- typecheck # Can cause issues with generics in some cases
linters-settings:
errcheck:
check-type-assertions: false
check-blank: false
govet:
check-shadowing: false
gofmt:
simplify: true
goimports:
local-prefixes: github.com/bitechdev/ResolveSpec
gocritic:
enabled-checks:
- appendAssign
- assignOp
- boolExprSimplify
- builtinShadow
- captLocal
- caseOrder
- defaultCaseOrder
- dupArg
- dupBranchBody
- dupCase
- dupSubExpr
- elseif
- emptyFallthrough
- equalFold
- flagName
- ifElseChain
- indexAlloc
- initClause
- methodExprCall
- nilValReturn
- rangeExprCopy
- rangeValCopy
- regexpMust
- singleCaseSwitch
- sloppyLen
- stringXbytes
- switchTrue
- typeAssertChain
- typeSwitchVar
- underef
- unlabelStmt
- unnamedResult
- unnecessaryBlock
- weakCond
- yodaStyleExpr
revive:
rules:
- name: exported
disabled: true
- name: package-comments
disabled: true
issues:
exclude-use-default: false
max-issues-per-linter: 0
max-same-issues: 0
# Exclude some linters from running on tests files
exclude-rules:
- path: _test\.go
linters:
- errcheck
- dupl
- gosec
- gocritic
# Ignore "error return value not checked" for defer statements
- linters:
- errcheck
text: "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
# Ignore complexity in test files
- path: _test\.go
text: "cognitive complexity|cyclomatic complexity"
output:
format: colored-line-number
print-issued-lines: true
print-linter-name: true

114
.golangci.json Normal file
View File

@ -0,0 +1,114 @@
{
"formatters": {
"enable": [
"gofmt",
"goimports"
],
"exclusions": {
"generated": "lax",
"paths": [
"third_party$",
"builtin$",
"examples$"
]
},
"settings": {
"gofmt": {
"simplify": true
},
"goimports": {
"local-prefixes": [
"github.com/bitechdev/ResolveSpec"
]
}
}
},
"issues": {
"max-issues-per-linter": 0,
"max-same-issues": 0
},
"linters": {
"enable": [
"gocritic",
"misspell",
"revive"
],
"exclusions": {
"generated": "lax",
"paths": [
"third_party$",
"builtin$",
"examples$",
"mocks?",
"tests?"
],
"rules": [
{
"linters": [
"dupl",
"errcheck",
"gocritic",
"gosec"
],
"path": "_test\\.go"
},
{
"linters": [
"errcheck"
],
"text": "Error return value of .((os\\.)?std(out|err)\\..*|.*Close|.*Flush|os\\.Remove(All)?|.*print(f|ln)?|os\\.(Un)?Setenv). is not checked"
},
{
"path": "_test\\.go",
"text": "cognitive complexity|cyclomatic complexity"
}
]
},
"settings": {
"errcheck": {
"check-blank": false,
"check-type-assertions": false
},
"gocritic": {
"enabled-checks": [
"boolExprSimplify",
"builtinShadow",
"emptyFallthrough",
"equalFold",
"indexAlloc",
"initClause",
"methodExprCall",
"nilValReturn",
"rangeExprCopy",
"rangeValCopy",
"stringXbytes",
"typeAssertChain",
"unlabelStmt",
"unnamedResult",
"unnecessaryBlock",
"weakCond",
"yodaStyleExpr"
],
"disabled-checks": [
"ifElseChain"
]
},
"revive": {
"rules": [
{
"disabled": true,
"name": "exported"
},
{
"disabled": true,
"name": "package-comments"
}
]
}
}
},
"run": {
"tests": true
},
"version": "2"
}

56
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,56 @@
{
"go.testFlags": [
"-v"
],
"go.testTimeout": "300s",
"go.coverOnSave": false,
"go.coverOnSingleTest": true,
"go.coverageDecorator": {
"type": "gutter"
},
"go.testEnvVars": {
"TEST_DATABASE_URL": "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5432 sslmode=disable"
},
"go.toolsEnvVars": {
"CGO_ENABLED": "0"
},
"go.buildTags": "",
"go.testTags": "",
"files.exclude": {
"**/.git": true,
"**/.DS_Store": true,
"**/coverage.out": true,
"**/coverage.html": true,
"**/coverage-integration.out": true,
"**/coverage-integration.html": true
},
"files.watcherExclude": {
"**/.git/objects/**": true,
"**/.git/subtree-cache/**": true,
"**/node_modules/*/**": true,
"**/.hg/store/**": true,
"**/vendor/**": true
},
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
},
"[go]": {
"editor.defaultFormatter": "golang.go",
"editor.formatOnSave": true,
"editor.insertSpaces": false,
"editor.tabSize": 4
},
"gopls": {
"ui.completion.usePlaceholders": true,
"ui.semanticTokens": true,
"ui.codelenses": {
"generate": true,
"regenerate_cgo": true,
"test": true,
"tidy": true,
"upgrade_dependency": true,
"vendor": true
}
}
}

268
.vscode/tasks.json vendored
View File

@ -9,7 +9,7 @@
"env": { "env": {
"CGO_ENABLED": "0" "CGO_ENABLED": "0"
}, },
"cwd": "${workspaceFolder}/bin", "cwd": "${workspaceFolder}/bin"
}, },
"args": [ "args": [
"../..." "../..."
@ -17,28 +17,272 @@
"problemMatcher": [ "problemMatcher": [
"$go" "$go"
], ],
"group": "build", "group": "build"
},
{
"type": "shell",
"label": "test: unit tests (all)",
"command": "go test ./pkg/resolvespec ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": {
"kind": "test",
"isDefault": true
},
"presentation": {
"reveal": "always",
"panel": "shared",
"focus": true
}
},
{
"type": "shell",
"label": "test: unit tests (resolvespec)",
"command": "go test ./pkg/resolvespec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: unit tests (restheadspec)",
"command": "go test ./pkg/restheadspec -v -cover",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration tests (automated)",
"command": "./scripts/run-integration-tests.sh",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: integration tests (resolvespec only)",
"command": "./scripts/run-integration-tests.sh resolvespec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: integration tests (restheadspec only)",
"command": "./scripts/run-integration-tests.sh restheadspec",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "test: coverage report",
"command": "make coverage",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "test: integration coverage report",
"command": "make coverage-integration",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: start postgres",
"command": "make docker-up",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: stop postgres",
"command": "make docker-down",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
},
{
"type": "shell",
"label": "docker: clean postgres data",
"command": "make clean",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"presentation": {
"reveal": "always",
"panel": "shared"
}
}, },
{ {
"type": "go", "type": "go",
"label": "go: test workspace", "label": "go: test workspace (with race)",
"command": "test", "command": "test",
"options": { "options": {
"env": { "cwd": "${workspaceFolder}"
"CGO_ENABLED": "0"
},
"cwd": "${workspaceFolder}/bin",
}, },
"args": [ "args": [
"../..." "-v",
"-race",
"-coverprofile=coverage.out",
"-covermode=atomic",
"./..."
], ],
"problemMatcher": [ "problemMatcher": [
"$go" "$go"
], ],
"group": "build", "group": "test",
"presentation": {
"reveal": "always",
"panel": "shared"
}
}, },
{
"type": "shell",
"label": "go: vet workspace",
"command": "go vet ./...",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [
"$go"
],
"group": "test"
},
{
"type": "shell",
"label": "go: lint workspace",
"command": "golangci-lint run --timeout=5m",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "build"
},
{
"type": "shell",
"label": "go: lint workspace (fix)",
"command": "golangci-lint run --timeout=5m --fix",
"options": {
"cwd": "${workspaceFolder}"
},
"problemMatcher": [],
"group": "build"
},
{
"type": "shell",
"label": "test: all tests (unit + integration)",
"command": "make test",
"options": {
"cwd": "${workspaceFolder}"
},
"dependsOn": [
"docker: start postgres"
],
"problemMatcher": [
"$go"
],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated",
"focus": true
}
},
{
"type": "shell",
"label": "test: full suite with checks",
"dependsOrder": "sequence",
"dependsOn": [
"go: vet workspace",
"test: unit tests (all)",
"test: integration tests (automated)"
],
"problemMatcher": [],
"group": "test",
"presentation": {
"reveal": "always",
"panel": "dedicated"
}
},
{
"type": "shell",
"label": "Make Release",
"problemMatcher": [],
"command": "sh ${workspaceFolder}/make_release.sh"
}
] ]
} }

View File

@ -1,6 +1,6 @@
MIT License MIT License
Copyright (c) 2025 Warky Devs Pty Ltd Copyright (c) 2025
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal

View File

@ -1,173 +0,0 @@
# Migration Guide: Database and Router Abstraction
This guide explains how to migrate from the direct GORM/Router dependencies to the new abstracted interfaces.
## Overview of Changes
### What was changed:
1. **Database Operations**: GORM-specific code is now abstracted behind `Database` interface
2. **Router Integration**: HTTP router dependencies are abstracted behind `Router` interface
3. **Model Registry**: Models are now managed through a `ModelRegistry` interface
4. **Backward Compatibility**: Existing code continues to work with `NewAPIHandler()`
### Benefits:
- **Database Flexibility**: Switch between GORM, Bun, or other ORMs without code changes
- **Router Flexibility**: Use Gorilla Mux, Gin, Echo, or other routers
- **Better Testing**: Easy to mock database and router interactions
- **Cleaner Separation**: Business logic separated from ORM/router specifics
## Migration Path
### Option 1: No Changes Required (Backward Compatible)
Your existing code continues to work without any changes:
```go
// This still works exactly as before
handler := resolvespec.NewAPIHandler(db)
```
### Option 2: Gradual Migration to New API
#### Step 1: Use New Handler Constructor
```go
// Old way
handler := resolvespec.NewAPIHandler(gormDB)
// New way
handler := resolvespec.NewHandlerWithGORM(gormDB)
```
#### Step 2: Use Interface-based Approach
```go
// Create database adapter
dbAdapter := resolvespec.NewGormAdapter(gormDB)
// Create model registry
registry := resolvespec.NewModelRegistry()
// Register your models
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
// Create handler
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Switching Database Backends
### From GORM to Bun
```go
// Add bun dependency first:
// go get github.com/uptrace/bun
// Old GORM setup
gormDB, _ := gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
gormAdapter := resolvespec.NewGormAdapter(gormDB)
// New Bun setup
sqlDB, _ := sql.Open("sqlite3", "test.db")
bunDB := bun.NewDB(sqlDB, sqlitedialect.New())
bunAdapter := resolvespec.NewBunAdapter(bunDB)
// Handler creation is identical
handler := resolvespec.NewHandler(bunAdapter, registry)
```
## Router Flexibility
### Current Gorilla Mux (Default)
```go
router := mux.NewRouter()
resolvespec.SetupRoutes(router, handler)
```
### BunRouter (Built-in Support)
```go
// Simple setup
router := bunrouter.New()
resolvespec.SetupBunRouterWithResolveSpec(router, handler)
// Or using adapter
routerAdapter := resolvespec.NewStandardBunRouterAdapter()
// Use routerAdapter.GetBunRouter() for the underlying router
```
### Using Router Adapters (Advanced)
```go
// For when you want router abstraction
routerAdapter := resolvespec.NewStandardRouter()
routerAdapter.RegisterRoute("/{schema}/{entity}", handlerFunc)
```
## Model Registration
### Old Way (Still Works)
```go
// Models registered through existing models package
handler.RegisterModel("public", "users", &User{})
```
### New Way (Recommended)
```go
registry := resolvespec.NewModelRegistry()
registry.RegisterModel("public.users", &User{})
registry.RegisterModel("public.orders", &Order{})
handler := resolvespec.NewHandler(dbAdapter, registry)
```
## Interface Definitions
### Database Interface
```go
type Database interface {
NewSelect() SelectQuery
NewInsert() InsertQuery
NewUpdate() UpdateQuery
NewDelete() DeleteQuery
// ... transaction methods
}
```
### Available Adapters
- `GormAdapter` - For GORM (ready to use)
- `BunAdapter` - For Bun (add dependency: `go get github.com/uptrace/bun`)
- Easy to create custom adapters for other ORMs
## Testing Benefits
### Before (Tightly Coupled)
```go
// Hard to test - requires real GORM setup
func TestHandler(t *testing.T) {
db := setupRealGormDB()
handler := resolvespec.NewAPIHandler(db)
// ... test logic
}
```
### After (Mockable)
```go
// Easy to test - mock the interfaces
func TestHandler(t *testing.T) {
mockDB := &MockDatabase{}
mockRegistry := &MockModelRegistry{}
handler := resolvespec.NewHandler(mockDB, mockRegistry)
// ... test logic with mocks
}
```
## Breaking Changes
- **None for existing code** - Full backward compatibility maintained
- New interfaces are additive, not replacing existing APIs
## Recommended Migration Timeline
1. **Phase 1**: Use existing code (no changes needed)
2. **Phase 2**: Gradually adopt new constructors (`NewHandlerWithGORM`)
3. **Phase 3**: Move to interface-based approach when needed
4. **Phase 4**: Switch database backends if desired
## Getting Help
- Check example functions in `resolvespec.go`
- Review interface definitions in `database.go`
- Examine adapter implementations for patterns

66
Makefile Normal file
View File

@ -0,0 +1,66 @@
.PHONY: test test-unit test-integration docker-up docker-down clean
# Run all unit tests
test-unit:
@echo "Running unit tests..."
@go test ./pkg/resolvespec ./pkg/restheadspec -v -cover
# Run all integration tests (requires PostgreSQL)
test-integration:
@echo "Running integration tests..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
# Run all tests (unit + integration)
test: test-unit test-integration
# Start PostgreSQL for integration tests
docker-up:
@echo "Starting PostgreSQL container..."
@docker-compose up -d postgres-test
@echo "Waiting for PostgreSQL to be ready..."
@sleep 5
@echo "PostgreSQL is ready!"
# Stop PostgreSQL container
docker-down:
@echo "Stopping PostgreSQL container..."
@docker-compose down
# Clean up Docker volumes and test data
clean:
@echo "Cleaning up..."
@docker-compose down -v
@echo "Cleanup complete!"
# Run integration tests with Docker (full workflow)
test-integration-docker: docker-up
@echo "Running integration tests with Docker..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -v
@$(MAKE) docker-down
# Check test coverage
coverage:
@echo "Generating coverage report..."
@go test ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage.out
@go tool cover -html=coverage.out -o coverage.html
@echo "Coverage report generated: coverage.html"
# Run integration tests coverage
coverage-integration:
@echo "Generating integration test coverage report..."
@go test -tags=integration ./pkg/resolvespec ./pkg/restheadspec -coverprofile=coverage-integration.out
@go tool cover -html=coverage-integration.out -o coverage-integration.html
@echo "Integration coverage report generated: coverage-integration.html"
help:
@echo "Available targets:"
@echo " test-unit - Run unit tests"
@echo " test-integration - Run integration tests (requires PostgreSQL)"
@echo " test - Run all tests"
@echo " docker-up - Start PostgreSQL container"
@echo " docker-down - Stop PostgreSQL container"
@echo " test-integration-docker - Run integration tests with Docker (automated)"
@echo " clean - Clean up Docker volumes"
@echo " coverage - Generate unit test coverage report"
@echo " coverage-integration - Generate integration test coverage report"
@echo " help - Show this help message"

975
README.md

File diff suppressed because it is too large Load Diff

View File

@ -1,17 +1,18 @@
package main package main
import ( import (
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
"time" "time"
"github.com/Warky-Devs/ResolveSpec/pkg/logger" "github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/Warky-Devs/ResolveSpec/pkg/modelregistry" "github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/Warky-Devs/ResolveSpec/pkg/testmodels" "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/server"
"github.com/bitechdev/ResolveSpec/pkg/testmodels"
"github.com/Warky-Devs/ResolveSpec/pkg/resolvespec" "github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
@ -20,12 +21,27 @@ import (
) )
func main() { func main() {
// Initialize logger // Load configuration
fmt.Println("ResolveSpec test server starting") cfgMgr := config.NewManager()
logger.Init(true) if err := cfgMgr.Load(); err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
cfg, err := cfgMgr.GetConfig()
if err != nil {
log.Fatalf("Failed to get configuration: %v", err)
}
// Initialize logger with configuration
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
logger.Info("ResolveSpec test server starting")
logger.Info("Configuration loaded - Server will listen on: %s", cfg.Server.Addr)
// Initialize database // Initialize database
db, err := initDB() db, err := initDB(cfg)
if err != nil { if err != nil {
logger.Error("Failed to initialize database: %+v", err) logger.Error("Failed to initialize database: %+v", err)
os.Exit(1) os.Exit(1)
@ -48,32 +64,54 @@ func main() {
handler.RegisterModel("public", modelNames[i], model) handler.RegisterModel("public", modelNames[i], model)
} }
// Setup routes using new SetupMuxRoutes function // Setup routes using new SetupMuxRoutes function (without authentication)
resolvespec.SetupMuxRoutes(r, handler) resolvespec.SetupMuxRoutes(r, handler, nil)
// Start server // Create graceful server with configuration
logger.Info("Starting server on :8080") srv := server.NewGracefulServer(server.Config{
if err := http.ListenAndServe(":8080", r); err != nil { Addr: cfg.Server.Addr,
Handler: r,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
DrainTimeout: cfg.Server.DrainTimeout,
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
})
// Start server with graceful shutdown
logger.Info("Starting server on %s", cfg.Server.Addr)
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
logger.Error("Server failed to start: %v", err) logger.Error("Server failed to start: %v", err)
os.Exit(1) os.Exit(1)
} }
} }
func initDB() (*gorm.DB, error) { func initDB(cfg *config.Config) (*gorm.DB, error) {
// Configure GORM logger based on config
logLevel := gormlog.Info
if !cfg.Logger.Dev {
logLevel = gormlog.Warn
}
newLogger := gormlog.New( newLogger := gormlog.New(
log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer log.New(os.Stdout, "\r\n", log.LstdFlags), // io writer
gormlog.Config{ gormlog.Config{
SlowThreshold: time.Second, // Slow SQL threshold SlowThreshold: time.Second, // Slow SQL threshold
LogLevel: gormlog.Info, // Log level LogLevel: logLevel, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
ParameterizedQueries: true, // Don't include params in the SQL log ParameterizedQueries: true, // Don't include params in the SQL log
Colorful: true, // Disable color Colorful: cfg.Logger.Dev,
}, },
) )
// Use database URL from config if available, otherwise use default SQLite
dbURL := cfg.Database.URL
if dbURL == "" {
dbURL = "test.db"
}
// Create SQLite database // Create SQLite database
db, err := gorm.Open(sqlite.Open("test.db"), &gorm.Config{Logger: newLogger, FullSaveAssociations: false}) db, err := gorm.Open(sqlite.Open(dbURL), &gorm.Config{Logger: newLogger, FullSaveAssociations: false})
if err != nil { if err != nil {
return nil, err return nil, err
} }

41
config.yaml Normal file
View File

@ -0,0 +1,41 @@
# ResolveSpec Test Server Configuration
# This is a minimal configuration for the test server
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
logger:
dev: true # Enable development mode for readable logs
path: "" # Empty means log to stdout
cache:
provider: "memory"
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
tracing:
enabled: false
database:
url: "" # Empty means use default SQLite (test.db)

57
config.yaml.example Normal file
View File

@ -0,0 +1,57 @@
# ResolveSpec Configuration Example
# This file demonstrates all available configuration options
# Copy this file to config.yaml and customize as needed
server:
addr: ":8080"
shutdown_timeout: 30s
drain_timeout: 25s
read_timeout: 10s
write_timeout: 10s
idle_timeout: 120s
tracing:
enabled: false
service_name: "resolvespec"
service_version: "1.0.0"
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
logger:
dev: false
path: "" # Empty for stdout, or specify file path
middleware:
rate_limit_rps: 100.0
rate_limit_burst: 200
max_request_size: 10485760 # 10MB in bytes
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
database:
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"

27
docker-compose.yml Normal file
View File

@ -0,0 +1,27 @@
services:
postgres-test:
image: postgres:15-alpine
container_name: resolvespec-postgres-test
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
- "5434:5432"
volumes:
- postgres-test-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
networks:
- resolvespec-test
volumes:
postgres-test-data:
driver: local
networks:
resolvespec-test:
driver: bridge

80
go.mod
View File

@ -1,39 +1,97 @@
module github.com/Warky-Devs/ResolveSpec module github.com/bitechdev/ResolveSpec
go 1.23.0 go 1.24.0
toolchain go1.24.6 toolchain go1.24.6
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
github.com/getsentry/sentry-go v0.40.0
github.com/glebarez/sqlite v1.11.0 github.com/glebarez/sqlite v1.11.0
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.1 github.com/gorilla/mux v1.8.1
github.com/stretchr/testify v1.8.1 github.com/prometheus/client_golang v1.23.2
github.com/redis/go-redis/v9 v9.17.1
github.com/spf13/viper v1.21.0
github.com/stretchr/testify v1.11.1
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/uptrace/bun v1.2.15 github.com/uptrace/bun v1.2.15
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15
github.com/uptrace/bun/driver/sqliteshim v1.2.15
github.com/uptrace/bunrouter v1.0.23
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0 go.uber.org/zap v1.27.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.25.12 gorm.io/gorm v1.25.12
) )
require ( require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/glebarez/go-sqlite v1.21.2 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.6.0 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-sqlite3 v1.14.28 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncruces/go-strftime v0.1.9 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.66.1 // indirect
github.com/prometheus/procfs v0.16.1 // indirect
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/uptrace/bunrouter v1.0.23 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
go.uber.org/multierr v1.10.0 // indirect go.uber.org/multierr v1.10.0 // indirect
golang.org/x/sys v0.34.0 // indirect go.yaml.in/yaml/v2 v2.4.2 // indirect
golang.org/x/text v0.21.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/crypto v0.41.0 // indirect
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
google.golang.org/grpc v1.75.0 // indirect
google.golang.org/protobuf v1.36.8 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
modernc.org/libc v1.22.5 // indirect modernc.org/libc v1.66.3 // indirect
modernc.org/mathutil v1.5.0 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.5.0 // indirect modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.23.1 // indirect modernc.org/sqlite v1.38.0 // indirect
) )

214
go.sum
View File

@ -1,76 +1,240 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM=
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o=
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg= github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo=
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU=
github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE= github.com/uptrace/bun v1.2.15 h1:Ut68XRBLDgp9qG9QBMa9ELWaZOmzHNdczHQdrOZbEFE=
github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038= github.com/uptrace/bun v1.2.15/go.mod h1:Eghz7NonZMiTX/Z6oKYytJ0oaMEJ/eq3kEV4vSqG038=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15 h1:7upGMVjFRB1oI78GQw6ruNLblYn5CR+kxqcbbeBBils=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.15/go.mod h1:c7YIDaPNS2CU2uI1p7umFuFWkuKbDcPDDvp+DLHZnkI=
github.com/uptrace/bun/driver/sqliteshim v1.2.15 h1:M/rZJSjOPV4OmfTVnDPtL+wJmdMTqDUn8cuk5ycfABA=
github.com/uptrace/bun/driver/sqliteshim v1.2.15/go.mod h1:YqwxFyvM992XOCpGJtXyKPkgkb+aZpIIMzGbpaw1hIk=
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU= github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M= github.com/uptrace/bunrouter v1.0.23/go.mod h1:O3jAcl+5qgnF+ejhgkmbceEk0E/mqaK+ADOocdNpY8M=
github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8=
github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc h1:TS73t7x3KarrNd5qAipmspBDS1rkMcgVG/fS1aRb4Rc=
golang.org/x/exp v0.0.0-20250711185948-6ae5c78190dc/go.mod h1:A+z0yzpGtvnG90cToK5n2tu8UJVP2XUATh+r+sfOOOc=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4=
gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE= modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY= modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU=
modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE=
modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM=
modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM= modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk= modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ=
modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.38.0 h1:+4OrfPQ8pxHKuWG4md1JpR/EYAh3Md7TdejuuzE7EUI=
modernc.org/sqlite v1.38.0/go.mod h1:1Bj+yES4SVvBZ4cBOpVZ6QgesMCKpJZDq0nxYzOpmNE=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@ -4,23 +4,68 @@
read -p "Do you want to make a release version? (y/n): " make_release read -p "Do you want to make a release version? (y/n): " make_release
if [[ $make_release =~ ^[Yy]$ ]]; then if [[ $make_release =~ ^[Yy]$ ]]; then
# Ask the user for the version number # Get the latest tag from git
read -p "Enter the version number : " version latest_tag=$(git describe --tags --abbrev=0 2>/dev/null)
if [ -z "$latest_tag" ]; then
# No tags exist yet, start with v1.0.0
suggested_version="v1.0.0"
echo "No existing tags found. Starting with $suggested_version"
else
echo "Latest tag: $latest_tag"
# Remove 'v' prefix if present
version_number="${latest_tag#v}"
# Split version into major.minor.patch
IFS='.' read -r major minor patch <<< "$version_number"
# Increment patch version
patch=$((patch + 1))
# Construct new version
suggested_version="v${major}.${minor}.${patch}"
echo "Suggested next version: $suggested_version"
fi
# Ask the user for the version number with the suggested version as default
read -p "Enter the version number (press Enter for $suggested_version): " version
# Use suggested version if user pressed Enter without input
if [ -z "$version" ]; then
version="$suggested_version"
fi
# Prepend 'v' to the version if it doesn't start with it # Prepend 'v' to the version if it doesn't start with it
if ! [[ $version =~ ^v ]]; then if ! [[ $version =~ ^v ]]; then
version="v$version" version="v$version"
else
echo "Version already starts with 'v'."
fi fi
# Create an annotated tag # Get commit logs since the last tag
git tag -a "$version" -m "Released Core $version" if [ -z "$latest_tag" ]; then
# No previous tag, get all commits
commit_logs=$(git log --pretty=format:"- %s" --no-merges)
else
# Get commits since the last tag
commit_logs=$(git log "${latest_tag}..HEAD" --pretty=format:"- %s" --no-merges)
fi
# Create the tag message
if [ -z "$commit_logs" ]; then
tag_message="Release $version"
else
tag_message="Release $version
${commit_logs}"
fi
# Create an annotated tag with the commit logs
git tag -a "$version" -m "$tag_message"
# Push the tag to the remote repository # Push the tag to the remote repository
git push origin "$version" git push origin "$version"
echo "Tag $version created for Core and pushed to the remote repository." echo "Tag $version created and pushed to the remote repository."
else else
echo "No release version created." echo "No release version created."
fi fi

340
pkg/cache/README.md vendored Normal file
View File

@ -0,0 +1,340 @@
# Cache Package
A flexible, provider-based caching library for Go that supports multiple backend storage systems including in-memory, Redis, and Memcache.
## Features
- **Multiple Providers**: Support for in-memory, Redis, and Memcache backends
- **Pluggable Architecture**: Easy to add custom cache providers
- **Type-Safe API**: Automatic JSON serialization/deserialization
- **TTL Support**: Configurable time-to-live for cache entries
- **Context-Aware**: All operations support Go contexts
- **Statistics**: Built-in cache statistics and monitoring
- **Pattern Deletion**: Delete keys by pattern (Redis)
- **Lazy Loading**: GetOrSet pattern for easy cache-aside implementation
## Installation
```bash
go get github.com/bitechdev/ResolveSpec/pkg/cache
```
For Redis support:
```bash
go get github.com/redis/go-redis/v9
```
For Memcache support:
```bash
go get github.com/bradfitz/gomemcache/memcache
```
## Quick Start
### In-Memory Cache
```go
package main
import (
"context"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
func main() {
// Initialize with in-memory provider
cache.UseMemory(&cache.Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
defer cache.Close()
ctx := context.Background()
c := cache.GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John"}
c.Set(ctx, "user:1", user, 10*time.Minute)
// Retrieve a value
var retrieved User
c.Get(ctx, "user:1", &retrieved)
}
```
### Redis Cache
```go
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "",
DB: 0,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
### Memcache
```go
cache.UseMemcache(&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
})
defer cache.Close()
```
## API Reference
### Core Methods
#### Set
```go
Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error
```
Stores a value in the cache with automatic JSON serialization.
#### Get
```go
Get(ctx context.Context, key string, dest interface{}) error
```
Retrieves and deserializes a value from the cache.
#### SetBytes / GetBytes
```go
SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error
GetBytes(ctx context.Context, key string) ([]byte, error)
```
Store and retrieve raw bytes without serialization.
#### Delete
```go
Delete(ctx context.Context, key string) error
```
Removes a key from the cache.
#### DeleteByPattern
```go
DeleteByPattern(ctx context.Context, pattern string) error
```
Removes all keys matching a pattern (Redis only).
#### Clear
```go
Clear(ctx context.Context) error
```
Removes all items from the cache.
#### Exists
```go
Exists(ctx context.Context, key string) bool
```
Checks if a key exists in the cache.
#### GetOrSet
```go
GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration,
loader func() (interface{}, error)) error
```
Retrieves a value from cache, or loads and caches it if not found (lazy loading).
#### Stats
```go
Stats(ctx context.Context) (*CacheStats, error)
```
Returns cache statistics including hits, misses, and key counts.
## Provider Configuration
### In-Memory Options
```go
&cache.Options{
DefaultTTL: 5 * time.Minute, // Default expiration time
MaxSize: 10000, // Maximum number of items
EvictionPolicy: "LRU", // Eviction strategy (future)
}
```
### Redis Configuration
```go
&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Optional authentication
DB: 0, // Database number
PoolSize: 10, // Connection pool size
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
### Memcache Configuration
```go
&cache.MemcacheConfig{
Servers: []string{"localhost:11211"},
MaxIdleConns: 2,
Timeout: 1 * time.Second,
Options: &cache.Options{
DefaultTTL: 5 * time.Minute,
},
}
```
## Advanced Usage
### Custom Provider
```go
// Create a custom provider instance
memProvider := cache.NewMemoryProvider(&cache.Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
cache.Initialize(memProvider)
```
### Lazy Loading Pattern
```go
var data ExpensiveData
err := c.GetOrSet(ctx, "expensive:key", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if key is not in cache
return computeExpensiveData(), nil
})
```
### Query API Cache
The package includes specialized functions for caching query results:
```go
// Cache a query result
api := "GetUsers"
query := "SELECT * FROM users WHERE active = true"
tablenames := "users"
total := int64(150)
cache.PutQueryAPICache(ctx, api, query, tablenames, total)
// Retrieve cached query
hash := cache.HashQueryAPICache(api, query)
cachedQuery, err := cache.FetchQueryAPICache(ctx, hash)
```
## Provider Comparison
| Feature | In-Memory | Redis | Memcache |
|---------|-----------|-------|----------|
| Persistence | No | Yes | No |
| Distributed | No | Yes | Yes |
| Pattern Delete | No | Yes | No |
| Statistics | Full | Full | Limited |
| Atomic Operations | Yes | Yes | Yes |
| Max Item Size | Memory | 512MB | 1MB |
## Best Practices
1. **Use contexts**: Always pass context for cancellation and timeout control
2. **Set appropriate TTLs**: Balance between freshness and performance
3. **Handle errors**: Cache misses and errors should be handled gracefully
4. **Monitor statistics**: Use Stats() to monitor cache performance
5. **Clean up**: Always call Close() when shutting down
6. **Pattern consistency**: Use consistent key naming patterns (e.g., "user:id:field")
## Example: Complete Application
```go
package main
import (
"context"
"log"
"time"
"github.com/bitechdev/ResolveSpec/pkg/cache"
)
type UserService struct {
cache *cache.Cache
}
func NewUserService() *UserService {
// Initialize with Redis in production, memory for testing
cache.UseRedis(&cache.RedisConfig{
Host: "localhost",
Port: 6379,
Options: &cache.Options{
DefaultTTL: 10 * time.Minute,
},
})
return &UserService{
cache: cache.GetDefaultCache(),
}
}
func (s *UserService) GetUser(ctx context.Context, userID int) (*User, error) {
var user User
cacheKey := fmt.Sprintf("user:%d", userID)
// Try to get from cache first
err := s.cache.GetOrSet(ctx, cacheKey, &user, 15*time.Minute, func() (interface{}, error) {
// Load from database if not in cache
return s.loadUserFromDB(userID)
})
if err != nil {
return nil, err
}
return &user, nil
}
func (s *UserService) InvalidateUser(ctx context.Context, userID int) error {
cacheKey := fmt.Sprintf("user:%d", userID)
return s.cache.Delete(ctx, cacheKey)
}
func main() {
service := NewUserService()
defer cache.Close()
ctx := context.Background()
user, err := service.GetUser(ctx, 123)
if err != nil {
log.Fatal(err)
}
log.Printf("User: %+v", user)
}
```
## Performance Considerations
- **In-Memory**: Fastest but limited by RAM and not distributed
- **Redis**: Great for distributed systems, persistent, but network overhead
- **Memcache**: Good for distributed caching, simpler than Redis but less features
Choose based on your needs:
- Single instance? Use in-memory
- Need persistence or advanced features? Use Redis
- Simple distributed cache? Use Memcache
## License
See repository license.

76
pkg/cache/cache.go vendored Normal file
View File

@ -0,0 +1,76 @@
package cache
import (
"context"
"fmt"
"time"
)
var (
defaultCache *Cache
)
// Initialize initializes the cache with a provider.
// If not called, the package will use an in-memory provider by default.
func Initialize(provider Provider) {
defaultCache = NewCache(provider)
}
// UseMemory configures the cache to use in-memory storage.
func UseMemory(opts *Options) error {
provider := NewMemoryProvider(opts)
defaultCache = NewCache(provider)
return nil
}
// UseRedis configures the cache to use Redis storage.
func UseRedis(config *RedisConfig) error {
provider, err := NewRedisProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Redis provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// UseMemcache configures the cache to use Memcache storage.
func UseMemcache(config *MemcacheConfig) error {
provider, err := NewMemcacheProvider(config)
if err != nil {
return fmt.Errorf("failed to initialize Memcache provider: %w", err)
}
defaultCache = NewCache(provider)
return nil
}
// GetDefaultCache returns the default cache instance.
// Initializes with in-memory provider if not already initialized.
func GetDefaultCache() *Cache {
if defaultCache == nil {
_ = UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
})
}
return defaultCache
}
// SetDefaultCache sets a custom cache instance as the default cache.
// This is useful for testing or when you want to use a pre-configured cache instance.
func SetDefaultCache(cache *Cache) {
defaultCache = cache
}
// GetStats returns cache statistics.
func GetStats(ctx context.Context) (*CacheStats, error) {
cache := GetDefaultCache()
return cache.Stats(ctx)
}
// Close closes the cache and releases resources.
func Close() error {
if defaultCache != nil {
return defaultCache.Close()
}
return nil
}

147
pkg/cache/cache_manager.go vendored Normal file
View File

@ -0,0 +1,147 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
)
// Cache is the main cache manager that wraps a Provider.
type Cache struct {
provider Provider
}
// NewCache creates a new cache manager with the specified provider.
func NewCache(provider Provider) *Cache {
return &Cache{
provider: provider,
}
}
// Get retrieves and deserializes a value from the cache.
func (c *Cache) Get(ctx context.Context, key string, dest interface{}) error {
data, exists := c.provider.Get(ctx, key)
if !exists {
return fmt.Errorf("key not found: %s", key)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize: %w", err)
}
return nil
}
// GetBytes retrieves raw bytes from the cache.
func (c *Cache) GetBytes(ctx context.Context, key string) ([]byte, error) {
data, exists := c.provider.Get(ctx, key)
if !exists {
return nil, fmt.Errorf("key not found: %s", key)
}
return data, nil
}
// Set serializes and stores a value in the cache with the specified TTL.
func (c *Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize: %w", err)
}
return c.provider.Set(ctx, key, data, ttl)
}
// SetBytes stores raw bytes in the cache with the specified TTL.
func (c *Cache) SetBytes(ctx context.Context, key string, value []byte, ttl time.Duration) error {
return c.provider.Set(ctx, key, value, ttl)
}
// Delete removes a key from the cache.
func (c *Cache) Delete(ctx context.Context, key string) error {
return c.provider.Delete(ctx, key)
}
// DeleteByPattern removes all keys matching the pattern.
func (c *Cache) DeleteByPattern(ctx context.Context, pattern string) error {
return c.provider.DeleteByPattern(ctx, pattern)
}
// Clear removes all items from the cache.
func (c *Cache) Clear(ctx context.Context) error {
return c.provider.Clear(ctx)
}
// Exists checks if a key exists in the cache.
func (c *Cache) Exists(ctx context.Context, key string) bool {
return c.provider.Exists(ctx, key)
}
// Stats returns statistics about the cache.
func (c *Cache) Stats(ctx context.Context) (*CacheStats, error) {
return c.provider.Stats(ctx)
}
// Close closes the cache and releases any resources.
func (c *Cache) Close() error {
return c.provider.Close()
}
// GetOrSet retrieves a value from cache, or sets it if it doesn't exist.
// The loader function is called only if the key is not found in cache.
func (c *Cache) GetOrSet(ctx context.Context, key string, dest interface{}, ttl time.Duration, loader func() (interface{}, error)) error {
// Try to get from cache first
err := c.Get(ctx, key, dest)
if err == nil {
return nil
}
// Load the value
value, err := loader()
if err != nil {
return fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return fmt.Errorf("failed to cache value: %w", err)
}
// Populate dest with the loaded value
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to serialize loaded value: %w", err)
}
if err := json.Unmarshal(data, dest); err != nil {
return fmt.Errorf("failed to deserialize loaded value: %w", err)
}
return nil
}
// Remember is a convenience function that caches the result of a function call.
// It's similar to GetOrSet but returns the value directly.
func (c *Cache) Remember(ctx context.Context, key string, ttl time.Duration, loader func() (interface{}, error)) (interface{}, error) {
// Try to get from cache first as bytes
data, err := c.GetBytes(ctx, key)
if err == nil {
var result interface{}
if err := json.Unmarshal(data, &result); err == nil {
return result, nil
}
}
// Load the value
value, err := loader()
if err != nil {
return nil, fmt.Errorf("loader failed: %w", err)
}
// Store in cache
if err := c.Set(ctx, key, value, ttl); err != nil {
return nil, fmt.Errorf("failed to cache value: %w", err)
}
return value, nil
}

69
pkg/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,69 @@
package cache
import (
"context"
"testing"
"time"
)
func TestSetDefaultCache(t *testing.T) {
// Create a custom cache instance
provider := NewMemoryProvider(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 50,
})
customCache := NewCache(provider)
// Set it as the default
SetDefaultCache(customCache)
// Verify it's now the default
retrievedCache := GetDefaultCache()
if retrievedCache != customCache {
t.Error("SetDefaultCache did not set the cache correctly")
}
// Test that we can use it
ctx := context.Background()
testKey := "test_key"
testValue := "test_value"
err := retrievedCache.Set(ctx, testKey, testValue, time.Minute)
if err != nil {
t.Fatalf("Failed to set value: %v", err)
}
var result string
err = retrievedCache.Get(ctx, testKey, &result)
if err != nil {
t.Fatalf("Failed to get value: %v", err)
}
if result != testValue {
t.Errorf("Expected %s, got %s", testValue, result)
}
// Clean up - reset to default
SetDefaultCache(nil)
}
func TestGetDefaultCacheInitialization(t *testing.T) {
// Reset to nil first
SetDefaultCache(nil)
// GetDefaultCache should auto-initialize
cache := GetDefaultCache()
if cache == nil {
t.Error("GetDefaultCache should auto-initialize, got nil")
}
// Should be usable
ctx := context.Background()
err := cache.Set(ctx, "test", "value", time.Minute)
if err != nil {
t.Errorf("Failed to use auto-initialized cache: %v", err)
}
// Clean up
SetDefaultCache(nil)
}

266
pkg/cache/example_usage.go vendored Normal file
View File

@ -0,0 +1,266 @@
package cache
import (
"context"
"fmt"
"log"
"time"
)
// ExampleInMemoryCache demonstrates using the in-memory cache provider.
func ExampleInMemoryCache() {
// Initialize with in-memory provider
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type User struct {
ID int
Name string
}
user := User{ID: 1, Name: "John Doe"}
err = cache.Set(ctx, "user:1", user, 10*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved User
err = cache.Get(ctx, "user:1", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved user: %+v\n", retrieved)
// Check if key exists
exists := cache.Exists(ctx, "user:1")
fmt.Printf("Key exists: %v\n", exists)
// Delete a key
err = cache.Delete(ctx, "user:1")
if err != nil {
_ = Close()
log.Fatal(err)
}
// Get statistics
stats, err := cache.Stats(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cache stats: %+v\n", stats)
_ = Close()
}
// ExampleRedisCache demonstrates using the Redis cache provider.
func ExampleRedisCache() {
// Initialize with Redis provider
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Password: "", // Set if Redis requires authentication
DB: 0,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store raw bytes
data := []byte("Hello, Redis!")
err = cache.SetBytes(ctx, "greeting", data, 1*time.Hour)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve raw bytes
retrieved, err := cache.GetBytes(ctx, "greeting")
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved data: %s\n", string(retrieved))
// Clear all cache
err = cache.Clear(ctx)
if err != nil {
_ = Close()
log.Fatal(err)
}
_ = Close()
}
// ExampleMemcacheCache demonstrates using the Memcache cache provider.
func ExampleMemcacheCache() {
// Initialize with Memcache provider
err := UseMemcache(&MemcacheConfig{
Servers: []string{"localhost:11211"},
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
// Get the cache instance
cache := GetDefaultCache()
// Store a value
type Product struct {
ID int
Name string
Price float64
}
product := Product{ID: 100, Name: "Widget", Price: 29.99}
err = cache.Set(ctx, "product:100", product, 30*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Retrieve a value
var retrieved Product
err = cache.Get(ctx, "product:100", &retrieved)
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Retrieved product: %+v\n", retrieved)
_ = Close()
}
// ExampleGetOrSet demonstrates the GetOrSet pattern for lazy loading.
func ExampleGetOrSet() {
err := UseMemory(&Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 1000,
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
type ExpensiveData struct {
Result string
}
var data ExpensiveData
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
// This expensive operation only runs if the key is not in cache
fmt.Println("Computing expensive result...")
time.Sleep(1 * time.Second)
return ExpensiveData{Result: "computed value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Data: %+v\n", data)
// Second call will use cached value
err = cache.GetOrSet(ctx, "expensive:computation", &data, 10*time.Minute, func() (interface{}, error) {
fmt.Println("This won't be called!")
return ExpensiveData{Result: "new value"}, nil
})
if err != nil {
_ = Close()
log.Fatal(err)
}
fmt.Printf("Cached data: %+v\n", data)
_ = Close()
}
// ExampleCustomProvider demonstrates using a custom provider.
func ExampleCustomProvider() {
// Create a custom provider
memProvider := NewMemoryProvider(&Options{
DefaultTTL: 10 * time.Minute,
MaxSize: 500,
})
// Initialize with custom provider
Initialize(memProvider)
ctx := context.Background()
cache := GetDefaultCache()
// Use the cache
err := cache.SetBytes(ctx, "key", []byte("value"), 5*time.Minute)
if err != nil {
_ = Close()
log.Fatal(err)
}
// Clean expired items (memory provider specific)
if mp, ok := cache.provider.(*MemoryProvider); ok {
count := mp.CleanExpired(ctx)
fmt.Printf("Cleaned %d expired items\n", count)
}
_ = Close()
}
// ExampleDeleteByPattern demonstrates pattern-based deletion (Redis only).
func ExampleDeleteByPattern() {
err := UseRedis(&RedisConfig{
Host: "localhost",
Port: 6379,
Options: &Options{
DefaultTTL: 5 * time.Minute,
},
})
if err != nil {
log.Fatal(err)
}
ctx := context.Background()
cache := GetDefaultCache()
// Store multiple keys with a pattern
_ = cache.SetBytes(ctx, "user:1:profile", []byte("profile1"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:2:profile", []byte("profile2"), 10*time.Minute)
_ = cache.SetBytes(ctx, "user:1:settings", []byte("settings1"), 10*time.Minute)
// Delete all keys matching pattern (Redis glob pattern)
err = cache.DeleteByPattern(ctx, "user:*:profile")
if err != nil {
_ = Close()
log.Print(err)
return
}
fmt.Println("Deleted all user profile keys")
_ = Close()
}

57
pkg/cache/provider.go vendored Normal file
View File

@ -0,0 +1,57 @@
package cache
import (
"context"
"time"
)
// Provider defines the interface that all cache providers must implement.
type Provider interface {
// Get retrieves a value from the cache by key.
// Returns nil, false if key doesn't exist or is expired.
Get(ctx context.Context, key string) ([]byte, bool)
// Set stores a value in the cache with the specified TTL.
// If ttl is 0, the item never expires.
Set(ctx context.Context, key string, value []byte, ttl time.Duration) error
// Delete removes a key from the cache.
Delete(ctx context.Context, key string) error
// DeleteByPattern removes all keys matching the pattern.
// Pattern syntax depends on the provider implementation.
DeleteByPattern(ctx context.Context, pattern string) error
// Clear removes all items from the cache.
Clear(ctx context.Context) error
// Exists checks if a key exists in the cache.
Exists(ctx context.Context, key string) bool
// Close closes the provider and releases any resources.
Close() error
// Stats returns statistics about the cache provider.
Stats(ctx context.Context) (*CacheStats, error)
}
// CacheStats contains cache statistics.
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
ProviderType string `json:"provider_type"`
ProviderStats map[string]any `json:"provider_stats,omitempty"`
}
// Options contains configuration options for cache providers.
type Options struct {
// DefaultTTL is the default time-to-live for cache items.
DefaultTTL time.Duration
// MaxSize is the maximum number of items (for in-memory provider).
MaxSize int
// EvictionPolicy determines how items are evicted (LRU, LFU, etc).
EvictionPolicy string
}

144
pkg/cache/provider_memcache.go vendored Normal file
View File

@ -0,0 +1,144 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/bradfitz/gomemcache/memcache"
)
// MemcacheProvider is a Memcache implementation of the Provider interface.
type MemcacheProvider struct {
client *memcache.Client
options *Options
}
// MemcacheConfig contains Memcache-specific configuration.
type MemcacheConfig struct {
// Servers is a list of memcache server addresses (e.g., "localhost:11211")
Servers []string
// MaxIdleConns is the maximum number of idle connections (default: 2)
MaxIdleConns int
// Timeout for connection operations (default: 1 second)
Timeout time.Duration
// Options contains general cache options
Options *Options
}
// NewMemcacheProvider creates a new Memcache cache provider.
func NewMemcacheProvider(config *MemcacheConfig) (*MemcacheProvider, error) {
if config == nil {
config = &MemcacheConfig{
Servers: []string{"localhost:11211"},
}
}
if len(config.Servers) == 0 {
config.Servers = []string{"localhost:11211"}
}
if config.MaxIdleConns == 0 {
config.MaxIdleConns = 2
}
if config.Timeout == 0 {
config.Timeout = 1 * time.Second
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := memcache.New(config.Servers...)
client.MaxIdleConns = config.MaxIdleConns
client.Timeout = config.Timeout
// Test connection
if err := client.Ping(); err != nil {
return nil, fmt.Errorf("failed to connect to Memcache: %w", err)
}
return &MemcacheProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (m *MemcacheProvider) Get(ctx context.Context, key string) ([]byte, bool) {
item, err := m.client.Get(key)
if err == memcache.ErrCacheMiss {
return nil, false
}
if err != nil {
return nil, false
}
return item.Value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemcacheProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = m.options.DefaultTTL
}
item := &memcache.Item{
Key: key,
Value: value,
Expiration: int32(ttl.Seconds()),
}
return m.client.Set(item)
}
// Delete removes a key from the cache.
func (m *MemcacheProvider) Delete(ctx context.Context, key string) error {
err := m.client.Delete(key)
if err == memcache.ErrCacheMiss {
return nil
}
return err
}
// DeleteByPattern removes all keys matching the pattern.
// Note: Memcache does not support pattern-based deletion natively.
// This is a no-op for memcache and returns an error.
func (m *MemcacheProvider) DeleteByPattern(ctx context.Context, pattern string) error {
return fmt.Errorf("pattern-based deletion is not supported by Memcache")
}
// Clear removes all items from the cache.
func (m *MemcacheProvider) Clear(ctx context.Context) error {
return m.client.FlushAll()
}
// Exists checks if a key exists in the cache.
func (m *MemcacheProvider) Exists(ctx context.Context, key string) bool {
_, err := m.client.Get(key)
return err == nil
}
// Close closes the provider and releases any resources.
func (m *MemcacheProvider) Close() error {
// Memcache client doesn't have a close method
return nil
}
// Stats returns statistics about the cache provider.
// Note: Memcache provider returns limited statistics.
func (m *MemcacheProvider) Stats(ctx context.Context) (*CacheStats, error) {
stats := &CacheStats{
ProviderType: "memcache",
ProviderStats: map[string]any{
"note": "Memcache does not provide detailed statistics through the standard client",
},
}
return stats, nil
}

238
pkg/cache/provider_memory.go vendored Normal file
View File

@ -0,0 +1,238 @@
package cache
import (
"context"
"fmt"
"regexp"
"sync"
"sync/atomic"
"time"
)
// memoryItem represents a cached item in memory.
type memoryItem struct {
Value []byte
Expiration time.Time
LastAccess time.Time
HitCount int64
}
// isExpired checks if the item has expired.
func (m *memoryItem) isExpired() bool {
if m.Expiration.IsZero() {
return false
}
return time.Now().After(m.Expiration)
}
// MemoryProvider is an in-memory implementation of the Provider interface.
type MemoryProvider struct {
mu sync.RWMutex
items map[string]*memoryItem
options *Options
hits atomic.Int64
misses atomic.Int64
}
// NewMemoryProvider creates a new in-memory cache provider.
func NewMemoryProvider(opts *Options) *MemoryProvider {
if opts == nil {
opts = &Options{
DefaultTTL: 5 * time.Minute,
MaxSize: 10000,
}
}
return &MemoryProvider{
items: make(map[string]*memoryItem),
options: opts,
}
}
// Get retrieves a value from the cache by key.
func (m *MemoryProvider) Get(ctx context.Context, key string) ([]byte, bool) {
// First try with read lock for fast path
m.mu.RLock()
item, exists := m.items[key]
if !exists {
m.mu.RUnlock()
m.misses.Add(1)
return nil, false
}
if item.isExpired() {
m.mu.RUnlock()
// Upgrade to write lock to delete expired item
m.mu.Lock()
delete(m.items, key)
m.mu.Unlock()
m.misses.Add(1)
return nil, false
}
// Update stats and access time with write lock
value := item.Value
m.mu.RUnlock()
// Update access tracking with write lock
m.mu.Lock()
item.LastAccess = time.Now()
item.HitCount++
m.mu.Unlock()
m.hits.Add(1)
return value, true
}
// Set stores a value in the cache with the specified TTL.
func (m *MemoryProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
m.mu.Lock()
defer m.mu.Unlock()
if ttl == 0 {
ttl = m.options.DefaultTTL
}
var expiration time.Time
if ttl > 0 {
expiration = time.Now().Add(ttl)
}
// Check max size and evict if necessary
if m.options.MaxSize > 0 && len(m.items) >= m.options.MaxSize {
if _, exists := m.items[key]; !exists {
m.evictOne()
}
}
m.items[key] = &memoryItem{
Value: value,
Expiration: expiration,
LastAccess: time.Now(),
}
return nil
}
// Delete removes a key from the cache.
func (m *MemoryProvider) Delete(ctx context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.items, key)
return nil
}
// DeleteByPattern removes all keys matching the pattern.
func (m *MemoryProvider) DeleteByPattern(ctx context.Context, pattern string) error {
m.mu.Lock()
defer m.mu.Unlock()
re, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("invalid pattern: %w", err)
}
for key := range m.items {
if re.MatchString(key) {
delete(m.items, key)
}
}
return nil
}
// Clear removes all items from the cache.
func (m *MemoryProvider) Clear(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = make(map[string]*memoryItem)
m.hits.Store(0)
m.misses.Store(0)
return nil
}
// Exists checks if a key exists in the cache.
func (m *MemoryProvider) Exists(ctx context.Context, key string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
item, exists := m.items[key]
if !exists {
return false
}
return !item.isExpired()
}
// Close closes the provider and releases any resources.
func (m *MemoryProvider) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
m.items = nil
return nil
}
// Stats returns statistics about the cache provider.
func (m *MemoryProvider) Stats(ctx context.Context) (*CacheStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
// Clean expired items first
validKeys := 0
for _, item := range m.items {
if !item.isExpired() {
validKeys++
}
}
return &CacheStats{
Hits: m.hits.Load(),
Misses: m.misses.Load(),
Keys: int64(validKeys),
ProviderType: "memory",
ProviderStats: map[string]any{
"capacity": m.options.MaxSize,
},
}, nil
}
// evictOne removes one item from the cache using LRU strategy.
func (m *MemoryProvider) evictOne() {
var oldestKey string
var oldestTime time.Time
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
return
}
if oldestKey == "" || item.LastAccess.Before(oldestTime) {
oldestKey = key
oldestTime = item.LastAccess
}
}
if oldestKey != "" {
delete(m.items, oldestKey)
}
}
// CleanExpired removes all expired items from the cache.
func (m *MemoryProvider) CleanExpired(ctx context.Context) int {
m.mu.Lock()
defer m.mu.Unlock()
count := 0
for key, item := range m.items {
if item.isExpired() {
delete(m.items, key)
count++
}
}
return count
}

185
pkg/cache/provider_redis.go vendored Normal file
View File

@ -0,0 +1,185 @@
package cache
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// RedisProvider is a Redis implementation of the Provider interface.
type RedisProvider struct {
client *redis.Client
options *Options
}
// RedisConfig contains Redis-specific configuration.
type RedisConfig struct {
// Host is the Redis server host (default: localhost)
Host string
// Port is the Redis server port (default: 6379)
Port int
// Password for Redis authentication (optional)
Password string
// DB is the Redis database number (default: 0)
DB int
// PoolSize is the maximum number of connections (default: 10)
PoolSize int
// Options contains general cache options
Options *Options
}
// NewRedisProvider creates a new Redis cache provider.
func NewRedisProvider(config *RedisConfig) (*RedisProvider, error) {
if config == nil {
config = &RedisConfig{
Host: "localhost",
Port: 6379,
DB: 0,
}
}
if config.Host == "" {
config.Host = "localhost"
}
if config.Port == 0 {
config.Port = 6379
}
if config.PoolSize == 0 {
config.PoolSize = 10
}
if config.Options == nil {
config.Options = &Options{
DefaultTTL: 5 * time.Minute,
}
}
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", config.Host, config.Port),
Password: config.Password,
DB: config.DB,
PoolSize: config.PoolSize,
})
// Test connection
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("failed to connect to Redis: %w", err)
}
return &RedisProvider{
client: client,
options: config.Options,
}, nil
}
// Get retrieves a value from the cache by key.
func (r *RedisProvider) Get(ctx context.Context, key string) ([]byte, bool) {
val, err := r.client.Get(ctx, key).Bytes()
if err == redis.Nil {
return nil, false
}
if err != nil {
return nil, false
}
return val, true
}
// Set stores a value in the cache with the specified TTL.
func (r *RedisProvider) Set(ctx context.Context, key string, value []byte, ttl time.Duration) error {
if ttl == 0 {
ttl = r.options.DefaultTTL
}
return r.client.Set(ctx, key, value, ttl).Err()
}
// Delete removes a key from the cache.
func (r *RedisProvider) Delete(ctx context.Context, key string) error {
return r.client.Del(ctx, key).Err()
}
// DeleteByPattern removes all keys matching the pattern.
func (r *RedisProvider) DeleteByPattern(ctx context.Context, pattern string) error {
iter := r.client.Scan(ctx, 0, pattern, 0).Iterator()
pipe := r.client.Pipeline()
count := 0
for iter.Next(ctx) {
pipe.Del(ctx, iter.Val())
count++
// Execute pipeline in batches of 100
if count%100 == 0 {
if _, err := pipe.Exec(ctx); err != nil {
return err
}
pipe = r.client.Pipeline()
}
}
if err := iter.Err(); err != nil {
return err
}
// Execute remaining commands
if count%100 != 0 {
_, err := pipe.Exec(ctx)
return err
}
return nil
}
// Clear removes all items from the cache.
func (r *RedisProvider) Clear(ctx context.Context) error {
return r.client.FlushDB(ctx).Err()
}
// Exists checks if a key exists in the cache.
func (r *RedisProvider) Exists(ctx context.Context, key string) bool {
result, err := r.client.Exists(ctx, key).Result()
if err != nil {
return false
}
return result > 0
}
// Close closes the provider and releases any resources.
func (r *RedisProvider) Close() error {
return r.client.Close()
}
// Stats returns statistics about the cache provider.
func (r *RedisProvider) Stats(ctx context.Context) (*CacheStats, error) {
info, err := r.client.Info(ctx, "stats", "keyspace").Result()
if err != nil {
return nil, fmt.Errorf("failed to get Redis stats: %w", err)
}
dbSize, err := r.client.DBSize(ctx).Result()
if err != nil {
return nil, fmt.Errorf("failed to get DB size: %w", err)
}
// Parse stats from INFO command
// This is a simplified version - you may want to parse more detailed stats
stats := &CacheStats{
Keys: dbSize,
ProviderType: "redis",
ProviderStats: map[string]any{
"info": info,
},
}
return stats, nil
}

127
pkg/cache/query_cache.go vendored Normal file
View File

@ -0,0 +1,127 @@
package cache
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// QueryCacheKey represents the components used to build a cache key for query total count
type QueryCacheKey struct {
TableName string `json:"table_name"`
Filters []common.FilterOption `json:"filters"`
Sort []common.SortOption `json:"sort"`
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
CustomSQLOr string `json:"custom_sql_or,omitempty"`
Expand []ExpandOptionKey `json:"expand,omitempty"`
Distinct bool `json:"distinct,omitempty"`
CursorForward string `json:"cursor_forward,omitempty"`
CursorBackward string `json:"cursor_backward,omitempty"`
}
// ExpandOptionKey represents expand options for cache key
type ExpandOptionKey struct {
Relation string `json:"relation"`
Where string `json:"where,omitempty"`
}
// BuildQueryCacheKey builds a cache key from query parameters for total count caching
// This is used to cache the total count of records matching a query
func BuildQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption, customWhere, customOr string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s", tableName, filters, sort, customWhere, customOr))
}
return hashString(string(jsonData))
}
// BuildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
// Includes expand, distinct, and cursor pagination options
func BuildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
key := QueryCacheKey{
TableName: tableName,
Filters: filters,
Sort: sort,
CustomSQLWhere: customWhere,
CustomSQLOr: customOr,
Distinct: distinct,
CursorForward: cursorFwd,
CursorBackward: cursorBwd,
}
// Convert expand options to cache key format
if len(expandOpts) > 0 {
key.Expand = make([]ExpandOptionKey, 0, len(expandOpts))
for _, exp := range expandOpts {
// Type assert to get the expand option fields we care about for caching
if expMap, ok := exp.(map[string]interface{}); ok {
expKey := ExpandOptionKey{}
if rel, ok := expMap["relation"].(string); ok {
expKey.Relation = rel
}
if where, ok := expMap["where"].(string); ok {
expKey.Where = where
}
key.Expand = append(key.Expand, expKey)
}
}
// Sort expand options for consistent hashing (already sorted by relation name above)
}
// Serialize to JSON for consistent hashing
jsonData, err := json.Marshal(key)
if err != nil {
// Fallback to simple string concatenation if JSON fails
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
}
return hashString(string(jsonData))
}
// hashString computes SHA256 hash of a string
func hashString(s string) string {
h := sha256.New()
h.Write([]byte(s))
return hex.EncodeToString(h.Sum(nil))
}
// GetQueryTotalCacheKey returns a formatted cache key for storing/retrieving total count
func GetQueryTotalCacheKey(hash string) string {
return fmt.Sprintf("query_total:%s", hash)
}
// CachedTotal represents a cached total count
type CachedTotal struct {
Total int `json:"total"`
}
// InvalidateCacheForTable removes all cached totals for a specific table
// This should be called when data in the table changes (insert/update/delete)
func InvalidateCacheForTable(ctx context.Context, tableName string) error {
cache := GetDefaultCache()
// Build a pattern to match all query totals for this table
// Note: This requires pattern matching support in the provider
pattern := fmt.Sprintf("query_total:*%s*", strings.ToLower(tableName))
return cache.DeleteByPattern(ctx, pattern)
}

151
pkg/cache/query_cache_test.go vendored Normal file
View File

@ -0,0 +1,151 @@
package cache
import (
"context"
"testing"
"time"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
func TestBuildQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "age", Operator: "gt", Value: 25},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
// Generate cache key
key1 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
// Same parameters should generate same key
key2 := BuildQueryCacheKey("users", filters, sorts, "status = 'active'", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters, got %s and %s", key1, key2)
}
// Different parameters should generate different key
key3 := BuildQueryCacheKey("users", filters, sorts, "status = 'inactive'", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different parameters, got %s and %s", key1, key3)
}
}
func TestBuildExtendedQueryCacheKey(t *testing.T) {
filters := []common.FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
}
sorts := []common.SortOption{
{Column: "name", Direction: "asc"},
}
expandOpts := []interface{}{
map[string]interface{}{
"relation": "posts",
"where": "status = 'published'",
},
}
// Generate cache key
key1 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
// Same parameters should generate same key
key2 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, false, "", "")
if key1 != key2 {
t.Errorf("Expected same cache keys for identical parameters")
}
// Different distinct value should generate different key
key3 := BuildExtendedQueryCacheKey("users", filters, sorts, "", "", expandOpts, true, "", "")
if key1 == key3 {
t.Errorf("Expected different cache keys for different distinct values")
}
}
func TestGetQueryTotalCacheKey(t *testing.T) {
hash := "abc123"
key := GetQueryTotalCacheKey(hash)
expected := "query_total:abc123"
if key != expected {
t.Errorf("Expected %s, got %s", expected, key)
}
}
func TestCachedTotalIntegration(t *testing.T) {
// Initialize cache with memory provider for testing
UseMemory(&Options{
DefaultTTL: 1 * time.Minute,
MaxSize: 100,
})
ctx := context.Background()
// Create test data
filters := []common.FilterOption{
{Column: "status", Operator: "eq", Value: "active"},
}
sorts := []common.SortOption{
{Column: "created_at", Direction: "desc"},
}
// Build cache key
cacheKeyHash := BuildQueryCacheKey("test_table", filters, sorts, "", "")
cacheKey := GetQueryTotalCacheKey(cacheKeyHash)
// Store a total count in cache
totalToCache := CachedTotal{Total: 42}
err := GetDefaultCache().Set(ctx, cacheKey, totalToCache, time.Minute)
if err != nil {
t.Fatalf("Failed to set cache: %v", err)
}
// Retrieve from cache
var cachedTotal CachedTotal
err = GetDefaultCache().Get(ctx, cacheKey, &cachedTotal)
if err != nil {
t.Fatalf("Failed to get from cache: %v", err)
}
if cachedTotal.Total != 42 {
t.Errorf("Expected total 42, got %d", cachedTotal.Total)
}
// Test cache miss
nonExistentKey := GetQueryTotalCacheKey("nonexistent")
var missedTotal CachedTotal
err = GetDefaultCache().Get(ctx, nonExistentKey, &missedTotal)
if err == nil {
t.Errorf("Expected error for cache miss, got nil")
}
}
func TestHashString(t *testing.T) {
input1 := "test string"
input2 := "test string"
input3 := "different string"
hash1 := hashString(input1)
hash2 := hashString(input2)
hash3 := hashString(input3)
// Same input should produce same hash
if hash1 != hash2 {
t.Errorf("Expected same hash for identical inputs")
}
// Different input should produce different hash
if hash1 == hash3 {
t.Errorf("Expected different hash for different inputs")
}
// Hash should be hex encoded SHA256 (64 characters)
if len(hash1) != 64 {
t.Errorf("Expected hash length of 64, got %d", len(hash1))
}
}

View File

@ -0,0 +1,218 @@
# Automatic Relation Loading Strategies
## Overview
**NEW:** The database adapters now **automatically** choose the optimal loading strategy by inspecting your model's relationship tags!
Simply use `PreloadRelation()` and the system automatically:
- Detects relationship type from Bun/GORM tags
- Uses **JOIN** for many-to-one and one-to-one (efficient, no duplication)
- Uses **separate query** for one-to-many and many-to-many (avoids duplication)
## How It Works
```go
// Just write this - the system handles the rest!
db.NewSelect().
Model(&links).
PreloadRelation("Provider"). // ✓ Auto-detects belongs-to → uses JOIN
PreloadRelation("Tags"). // ✓ Auto-detects has-many → uses separate query
Scan(ctx, &links)
```
### Detection Logic
The system inspects your model's struct tags:
**Bun models:**
```go
type Link struct {
Provider *Provider `bun:"rel:belongs-to"` // → Detected: belongs-to → JOIN
Tags []Tag `bun:"rel:has-many"` // → Detected: has-many → Separate query
}
```
**GORM models:**
```go
type Link struct {
ProviderID int
Provider *Provider `gorm:"foreignKey:ProviderID"` // → Detected: belongs-to → JOIN
Tags []Tag `gorm:"many2many:link_tags"` // → Detected: many-to-many → Separate query
}
```
**Type inference (fallback):**
- `[]Type` (slice) → has-many → Separate query
- `*Type` (pointer) → belongs-to → JOIN
- `Type` (struct) → belongs-to → JOIN
### What Gets Logged
Enable debug logging to see strategy selection:
```go
bunAdapter.EnableQueryDebug()
```
**Output:**
```
DEBUG: PreloadRelation 'Provider' detected as: belongs-to
INFO: Using JOIN strategy for belongs-to relation 'Provider'
DEBUG: PreloadRelation 'Links' detected as: has-many
DEBUG: Using separate query for has-many relation 'Links'
```
## Relationship Types
| Bun Tag | GORM Pattern | Field Type | Strategy | Why |
|---------|--------------|------------|----------|-----|
| `rel:has-many` | Slice field | `[]Type` | Separate Query | Avoids duplicating parent data |
| `rel:belongs-to` | `foreignKey:` | `*Type` | JOIN | Single parent, no duplication |
| `rel:has-one` | Single pointer | `*Type` | JOIN | One-to-one, no duplication |
| `rel:many-to-many` | `many2many:` | `[]Type` | Separate Query | Complex join, avoid cartesian |
## Manual Override
If you need to force a specific strategy, use `JoinRelation()`:
```go
// Force JOIN even for has-many (not recommended)
db.NewSelect().
Model(&providers).
JoinRelation("Links"). // Explicitly use JOIN
Scan(ctx, &providers)
```
## Examples
### Automatic Strategy Selection (Recommended)
```go
// Example 1: Loading parent provider for each link
// System detects belongs-to → uses JOIN automatically
db.NewSelect().
Model(&links).
PreloadRelation("Provider", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &links)
// Generated SQL: Single query with JOIN
// SELECT links.*, providers.*
// FROM links
// LEFT JOIN providers ON links.provider_id = providers.id
// WHERE providers.active = true
// Example 2: Loading child links for each provider
// System detects has-many → uses separate query automatically
db.NewSelect().
Model(&providers).
PreloadRelation("Links", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &providers)
// Generated SQL: Two queries
// Query 1: SELECT * FROM providers
// Query 2: SELECT * FROM links
// WHERE provider_id IN (1, 2, 3, ...)
// AND active = true
```
### Mixed Relationships
```go
type Order struct {
ID int
CustomerID int
Customer *Customer `bun:"rel:belongs-to"` // JOIN
Items []Item `bun:"rel:has-many"` // Separate
Invoice *Invoice `bun:"rel:has-one"` // JOIN
}
// All three handled optimally!
db.NewSelect().
Model(&orders).
PreloadRelation("Customer"). // → JOIN (many-to-one)
PreloadRelation("Items"). // → Separate (one-to-many)
PreloadRelation("Invoice"). // → JOIN (one-to-one)
Scan(ctx, &orders)
```
## Performance Benefits
### Before (Manual Strategy Selection)
```go
// You had to remember which to use:
.PreloadRelation("Provider") // Should I use PreloadRelation or JoinRelation?
.PreloadRelation("Links") // Which is more efficient here?
```
### After (Automatic Selection)
```go
// Just use PreloadRelation everywhere:
.PreloadRelation("Provider") // ✓ System uses JOIN automatically
.PreloadRelation("Links") // ✓ System uses separate query automatically
```
## Migration Guide
**No changes needed!** If you're already using `PreloadRelation()`, it now automatically optimizes:
```go
// Before: Always used separate query
.PreloadRelation("Provider") // Inefficient: extra round trip
// After: Automatic optimization
.PreloadRelation("Provider") // ✓ Now uses JOIN automatically!
```
## Implementation Details
### Supported Bun Tags
- `rel:has-many` → Separate query
- `rel:belongs-to` → JOIN
- `rel:has-one` → JOIN
- `rel:many-to-many` or `rel:m2m` → Separate query
### Supported GORM Patterns
- `many2many:` tag → Separate query
- `foreignKey:` tag → JOIN (belongs-to)
- `[]Type` slice without many2many → Separate query (has-many)
- `*Type` pointer with foreignKey → JOIN (belongs-to)
- `*Type` pointer without foreignKey → JOIN (has-one)
### Fallback Behavior
- `[]Type` (slice) → Separate query (safe default for collections)
- `*Type` or `Type` (single) → JOIN (safe default for single relations)
- Unknown → Separate query (safest default)
## Debugging
To see strategy selection in action:
```go
// Enable debug logging
bunAdapter.EnableQueryDebug() // or gormAdapter.EnableQueryDebug()
// Run your query
db.NewSelect().
Model(&records).
PreloadRelation("RelationName").
Scan(ctx, &records)
// Check logs for:
// - "PreloadRelation 'X' detected as: belongs-to"
// - "Using JOIN strategy for belongs-to relation 'X'"
// - Actual SQL queries executed
```
## Best Practices
1. **Use PreloadRelation() for everything** - Let the system optimize
2. **Define proper relationship tags** - Ensures correct detection
3. **Only use JoinRelation() for overrides** - When you know better than auto-detection
4. **Enable debug logging during development** - Verify optimal strategies are chosen
5. **Trust the system** - It's designed to choose correctly based on relationship type

View File

@ -0,0 +1,81 @@
package database
import (
"testing"
)
func TestNormalizeTableAlias(t *testing.T) {
tests := []struct {
name string
query string
expectedAlias string
tableName string
want string
}{
{
name: "strips plausible alias from simple condition",
query: "APIL.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = 2576",
},
{
name: "keeps correct alias",
query: "apiproviderlink.rid_hub = 2576",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "apiproviderlink.rid_hub = 2576",
},
{
name: "strips plausible alias with multiple conditions",
query: "APIL.rid_hub = ? AND APIL.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND active = ?",
},
{
name: "handles mixed correct and plausible aliases",
query: "APIL.rid_hub = ? AND apiproviderlink.active = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ? AND apiproviderlink.active = ?",
},
{
name: "handles parentheses",
query: "(APIL.rid_hub = ?)",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "(rid_hub = ?)",
},
{
name: "no alias in query",
query: "rid_hub = ?",
expectedAlias: "apiproviderlink",
tableName: "apiproviderlink",
want: "rid_hub = ?",
},
{
name: "keeps reference to different table (not in current table name)",
query: "APIL.rid_hub = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "APIL.rid_hub = ?",
},
{
name: "keeps reference with short prefix that might be ambiguous",
query: "AP.rid = ?",
expectedAlias: "apiprovider",
tableName: "apiprovider",
want: "AP.rid = ?",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTableAlias(tt.query, tt.expectedAlias, tt.tableName)
if got != tt.want {
t.Errorf("normalizeTableAlias() = %q, want %q", got, tt.want)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,213 @@
package database
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/driver/sqliteshim"
)
// TestInsertModel is a test model for insert operations
type TestInsertModel struct {
bun.BaseModel `bun:"table:test_inserts"`
ID int64 `bun:"id,pk,autoincrement"`
Name string `bun:"name,notnull"`
Email string `bun:"email"`
Age int `bun:"age"`
}
func setupBunTestDB(t *testing.T) *bun.DB {
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
require.NoError(t, err, "Failed to open SQLite database")
db := bun.NewDB(sqldb, sqlitedialect.New())
// Create test table
_, err = db.NewCreateTable().
Model((*TestInsertModel)(nil)).
IfNotExists().
Exec(context.Background())
require.NoError(t, err, "Failed to create test table")
return db
}
func TestBunInsertQuery_Model(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with Model()
model := &TestInsertModel{
Name: "John Doe",
Email: "john@example.com",
Age: 30,
}
result, err := adapter.NewInsert().
Model(model).
Returning("*").
Exec(ctx)
require.NoError(t, err, "Insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
// Verify the data was inserted
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("id = ?", model.ID).
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "John Doe", retrieved.Name)
assert.Equal(t, "john@example.com", retrieved.Email)
assert.Equal(t, 30, retrieved.Age)
}
func TestBunInsertQuery_Value(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with Value() method - this was the bug
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Jane Smith").
Value("email", "jane@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err, "Insert with Value() should succeed")
assert.Equal(t, int64(1), result.RowsAffected(), "Should insert 1 row")
// Verify the data was inserted
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("name = ?", "Jane Smith").
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "Jane Smith", retrieved.Name)
assert.Equal(t, "jane@example.com", retrieved.Email)
assert.Equal(t, 25, retrieved.Age)
}
func TestBunInsertQuery_MultipleValues(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting multiple values
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Alice").
Value("email", "alice@example.com").
Value("age", 28).
Exec(ctx)
require.NoError(t, err, "First insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
result, err = adapter.NewInsert().
Table("test_inserts").
Value("name", "Bob").
Value("email", "bob@example.com").
Value("age", 35).
Exec(ctx)
require.NoError(t, err, "Second insert should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
// Verify both rows exist
var count int
count, err = db.NewSelect().
Model((*TestInsertModel)(nil)).
Count(ctx)
require.NoError(t, err, "Count should succeed")
assert.Equal(t, 2, count, "Should have 2 rows")
}
func TestBunInsertQuery_ValueWithNil(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test inserting with nil value for nullable field
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Test User").
Value("email", nil). // NULL email
Value("age", 20).
Exec(ctx)
require.NoError(t, err, "Insert with nil value should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
// Verify the data was inserted with NULL email
var retrieved TestInsertModel
err = db.NewSelect().
Model(&retrieved).
Where("name = ?", "Test User").
Scan(ctx)
require.NoError(t, err, "Should retrieve inserted row")
assert.Equal(t, "Test User", retrieved.Name)
assert.Equal(t, "", retrieved.Email) // NULL becomes empty string
assert.Equal(t, 20, retrieved.Age)
}
func TestBunInsertQuery_Returning(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test insert with RETURNING clause
// Note: SQLite has limited RETURNING support, but this tests the API
result, err := adapter.NewInsert().
Table("test_inserts").
Value("name", "Return Test").
Value("email", "return@example.com").
Value("age", 40).
Returning("*").
Exec(ctx)
require.NoError(t, err, "Insert with RETURNING should succeed")
assert.Equal(t, int64(1), result.RowsAffected())
}
func TestBunInsertQuery_EmptyValues(t *testing.T) {
db := setupBunTestDB(t)
defer db.Close()
adapter := NewBunAdapter(db)
ctx := context.Background()
// Test insert without calling Value() - should use Model() or fail gracefully
result, err := adapter.NewInsert().
Table("test_inserts").
Exec(ctx)
// This should fail because no values are provided
assert.Error(t, err, "Insert without values should fail")
if result != nil {
assert.Equal(t, int64(0), result.RowsAffected())
}
}

View File

@ -2,9 +2,15 @@ package database
import ( import (
"context" "context"
"fmt"
"strings"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"gorm.io/gorm" "gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
) )
// GormAdapter adapts GORM to work with our Database interface // GormAdapter adapts GORM to work with our Database interface
@ -17,6 +23,22 @@ func NewGormAdapter(db *gorm.DB) *GormAdapter {
return &GormAdapter{db: db} return &GormAdapter{db: db}
} }
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
// This is useful for debugging preload queries that may be failing
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
g.db = g.db.Debug()
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
return g
}
// DisableQueryDebug disables query debugging
func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
// GORM's Debug() creates a new session, so we need to get the base DB
// This is a simplified implementation
logger.Info("GORM debug mode - create a new adapter without Debug() to disable")
return g
}
func (g *GormAdapter) NewSelect() common.SelectQuery { func (g *GormAdapter) NewSelect() common.SelectQuery {
return &GormSelectQuery{db: g.db} return &GormSelectQuery{db: g.db}
} }
@ -33,12 +55,22 @@ func (g *GormAdapter) NewDelete() common.DeleteQuery {
return &GormDeleteQuery{db: g.db} return &GormDeleteQuery{db: g.db}
} }
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) { func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Exec", r)
}
}()
result := g.db.WithContext(ctx).Exec(query, args...) result := g.db.WithContext(ctx).Exec(query, args...)
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }
func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error { func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.Query", r)
}
}()
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
} }
@ -58,25 +90,54 @@ func (g *GormAdapter) RollbackTx(ctx context.Context) error {
return g.db.WithContext(ctx).Rollback().Error return g.db.WithContext(ctx).Rollback().Error
} }
func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) error { func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Database) error) (err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
}
}()
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
adapter := &GormAdapter{db: tx} adapter := &GormAdapter{db: tx}
return fn(adapter) return fn(adapter)
}) })
} }
func (g *GormAdapter) GetUnderlyingDB() interface{} {
return g.db
}
// GormSelectQuery implements SelectQuery for GORM // GormSelectQuery implements SelectQuery for GORM
type GormSelectQuery struct { type GormSelectQuery struct {
db *gorm.DB db *gorm.DB
schema string // Separated schema name
tableName string // Just the table name, without schema
tableAlias string
inJoinContext bool // Track if we're in a JOIN relation context
joinTableAlias string // Alias to use for JOIN conditions
} }
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery { func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
g.db = g.db.Model(model) g.db = g.db.Model(model)
// Try to get table name from model if it implements TableNameProvider
if provider, ok := model.(common.TableNameProvider); ok {
fullTableName := provider.TableName()
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(fullTableName)
}
if provider, ok := model.(common.TableAliasProvider); ok {
g.tableAlias = provider.TableAlias()
}
return g return g
} }
func (g *GormSelectQuery) Table(table string) common.SelectQuery { func (g *GormSelectQuery) Table(table string) common.SelectQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
// Check if the table name contains schema (e.g., "schema.table")
g.schema, g.tableName = parseTableName(table)
return g return g
} }
@ -85,23 +146,146 @@ func (g *GormSelectQuery) Column(columns ...string) common.SelectQuery {
return g return g
} }
func (g *GormSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
if len(args) > 0 {
g.db = g.db.Select(query, args...)
} else {
g.db = g.db.Select(query)
}
return g
}
func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
// If we're in a JOIN context, add table prefix to unqualified columns
if g.inJoinContext && g.joinTableAlias != "" {
query = addTablePrefixGorm(query, g.joinTableAlias)
}
g.db = g.db.Where(query, args...) g.db = g.db.Where(query, args...)
return g return g
} }
// addTablePrefixGorm adds a table prefix to unqualified column references (GORM version)
func addTablePrefixGorm(query, tableAlias string) string {
if tableAlias == "" || query == "" {
return query
}
// Split on spaces and parentheses to find column references
parts := strings.FieldsFunc(query, func(r rune) bool {
return r == ' ' || r == '(' || r == ')' || r == ','
})
modified := query
for _, part := range parts {
// Check if this looks like an unqualified column reference
if !strings.Contains(part, ".") {
// Extract potential column name (before = or other operators)
for _, op := range []string{"=", "!=", "<>", ">", ">=", "<", "<=", " LIKE ", " IN ", " IS "} {
if strings.Contains(part, op) {
colName := strings.Split(part, op)[0]
colName = strings.TrimSpace(colName)
if colName != "" && !isOperatorOrKeywordGorm(colName) {
// Add table prefix
prefixed := tableAlias + "." + colName + strings.TrimPrefix(part, colName)
modified = strings.ReplaceAll(modified, part, prefixed)
logger.Debug("Adding table prefix '%s' to column '%s' in JOIN condition", tableAlias, colName)
}
break
}
}
}
}
return modified
}
// isOperatorOrKeywordGorm checks if a string is likely an operator or SQL keyword (GORM version)
func isOperatorOrKeywordGorm(s string) bool {
s = strings.ToUpper(strings.TrimSpace(s))
keywords := []string{"AND", "OR", "NOT", "IN", "IS", "NULL", "TRUE", "FALSE", "LIKE", "BETWEEN"}
for _, kw := range keywords {
if s == kw {
return true
}
}
return false
}
func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Or(query, args...) g.db = g.db.Or(query, args...)
return g return g
} }
func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) Join(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Joins(query, args...) // Extract optional prefix from args
// If the last arg is a string that looks like a table prefix, use it
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
// Likely a prefix, not a SQL parameter
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
}
// If prefix is provided, add it as an alias in the join
// GORM expects: "JOIN table AS alias ON condition"
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
// If query doesn't already have AS, check if it's a simple table name
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
// Simple table name, add prefix: "table AS prefix"
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
// Has ON clause: "table ON ..." becomes "table AS prefix ON ..."
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
g.db = g.db.Joins(joinClause, sqlArgs...)
return g return g
} }
func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery { func (g *GormSelectQuery) LeftJoin(query string, args ...interface{}) common.SelectQuery {
g.db = g.db.Joins("LEFT JOIN "+query, args...) // Extract optional prefix from args
var prefix string
sqlArgs := args
if len(args) > 0 {
if lastArg, ok := args[len(args)-1].(string); ok && len(lastArg) < 50 && !strings.Contains(lastArg, " ") {
prefix = lastArg
sqlArgs = args[:len(args)-1]
}
}
// If no prefix provided, use the table name as prefix (already separated from schema)
if prefix == "" && g.tableName != "" {
prefix = g.tableName
}
// Construct LEFT JOIN with prefix
joinClause := query
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
parts := strings.Fields(query)
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
if len(parts) > 1 {
joinClause += " " + strings.Join(parts[1:], " ")
}
}
}
g.db = g.db.Joins("LEFT JOIN "+joinClause, sqlArgs...)
return g return g
} }
@ -110,6 +294,93 @@ func (g *GormSelectQuery) Preload(relation string, conditions ...interface{}) co
return g return g
} }
func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// Auto-detect relationship type and choose optimal loading strategy
// Get the model from GORM's statement if available
if g.db.Statement != nil && g.db.Statement.Model != nil {
relType := reflection.GetRelationType(g.db.Statement.Model, relation)
// Log the detected relationship type
logger.Debug("PreloadRelation '%s' detected as: %s", relation, relType)
// If this is a belongs-to or has-one relation, use JOIN for better performance
if relType.ShouldUseJoin() {
logger.Info("Using JOIN strategy for %s relation '%s'", relType, relation)
return g.JoinRelation(relation, apply...)
}
// For has-many, many-to-many, or unknown: use separate query (safer default)
if relType == reflection.RelationHasMany || relType == reflection.RelationManyToMany {
logger.Debug("Using separate query for %s relation '%s'", relType, relation)
}
}
// Use GORM's Preload (separate query strategy)
g.db = g.db.Preload(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
modified := fn(current)
current = modified
}
}
if finalBun, ok := current.(*GormSelectQuery); ok {
return finalBun.db
}
return db // fallback
})
return g
}
func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
// JoinRelation uses a JOIN instead of a separate preload query
// This is more efficient for many-to-one or one-to-one relationships
// as it avoids additional round trips to the database
// GORM's Joins() method forces a JOIN for the preload
logger.Debug("JoinRelation '%s' - Using GORM Joins() with automatic WHERE prefix addition", relation)
g.db = g.db.Joins(relation, func(db *gorm.DB) *gorm.DB {
if len(apply) == 0 {
return db
}
wrapper := &GormSelectQuery{
db: db,
inJoinContext: true, // Mark as JOIN context
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
}
current := common.SelectQuery(wrapper)
for _, fn := range apply {
if fn != nil {
current = fn(current)
}
}
if finalGorm, ok := current.(*GormSelectQuery); ok {
return finalGorm.db
}
return db
})
return g
}
func (g *GormSelectQuery) Order(order string) common.SelectQuery { func (g *GormSelectQuery) Order(order string) common.SelectQuery {
g.db = g.db.Order(order) g.db = g.db.Order(order)
return g return g
@ -135,19 +406,78 @@ func (g *GormSelectQuery) Having(having string, args ...interface{}) common.Sele
return g return g
} }
func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) error { func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
return g.db.WithContext(ctx).Find(dest).Error defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Scan", r)
}
}()
err = g.db.WithContext(ctx).Find(dest).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(dest)
})
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) Count(ctx context.Context) (int, error) { func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
var count int64 defer func() {
err := g.db.WithContext(ctx).Count(&count).Error if r := recover(); r != nil {
return int(count), err err = logger.HandlePanic("GormSelectQuery.ScanModel", r)
}
}()
if g.db.Statement.Model == nil {
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
}
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Find(g.db.Statement.Model)
})
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
}
return err
} }
func (g *GormSelectQuery) Exists(ctx context.Context) (bool, error) { func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Count", r)
count = 0
}
}()
var count64 int64
err = g.db.WithContext(ctx).Count(&count64).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Count(&count64)
})
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
}
return int(count64), err
}
func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormSelectQuery.Exists", r)
exists = false
}
}()
var count int64 var count int64
err := g.db.WithContext(ctx).Limit(1).Count(&count).Error err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
if err != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Limit(1).Count(&count)
})
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
}
return count > 0, err return count > 0, err
} }
@ -187,13 +517,19 @@ func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
return g return g
} }
func (g *GormInsertQuery) Exec(ctx context.Context) (common.Result, error) { func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormInsertQuery.Exec", r)
}
}()
var result *gorm.DB var result *gorm.DB
if g.model != nil { switch {
case g.model != nil:
result = g.db.WithContext(ctx).Create(g.model) result = g.db.WithContext(ctx).Create(g.model)
} else if g.values != nil { case g.values != nil:
result = g.db.WithContext(ctx).Create(g.values) result = g.db.WithContext(ctx).Create(g.values)
} else { default:
result = g.db.WithContext(ctx).Create(map[string]interface{}{}) result = g.db.WithContext(ctx).Create(map[string]interface{}{})
} }
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
@ -214,10 +550,23 @@ func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery { func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
g.db = g.db.Table(table) g.db = g.db.Table(table)
if g.model == nil {
// Try to get table name from table string if model is not set
model, err := modelregistry.GetModelByName(table)
if err == nil {
g.model = model
}
}
return g return g
} }
func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery { func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQuery {
// Validate column is writable if model is set
if g.model != nil && !reflection.IsColumnWritable(g.model, column) {
// Skip read-only columns
return g
}
if g.updates == nil { if g.updates == nil {
g.updates = make(map[string]interface{}) g.updates = make(map[string]interface{})
} }
@ -228,7 +577,25 @@ func (g *GormUpdateQuery) Set(column string, value interface{}) common.UpdateQue
} }
func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery { func (g *GormUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuery {
// Filter out read-only columns if model is set
if g.model != nil {
pkName := reflection.GetPrimaryKeyName(g.model)
filteredValues := make(map[string]interface{})
for column, value := range values {
if pkName != "" && column == pkName {
// Skip primary key updates
continue
}
if reflection.IsColumnWritable(g.model, column) {
filteredValues[column] = value
}
}
g.updates = filteredValues
} else {
g.updates = values g.updates = values
}
return g return g
} }
@ -242,8 +609,20 @@ func (g *GormUpdateQuery) Returning(columns ...string) common.UpdateQuery {
return g return g
} }
func (g *GormUpdateQuery) Exec(ctx context.Context) (common.Result, error) { func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Updates(g.updates) result := g.db.WithContext(ctx).Updates(g.updates)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Updates(g.updates)
})
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }
@ -269,8 +648,20 @@ func (g *GormDeleteQuery) Where(query string, args ...interface{}) common.Delete
return g return g
} }
func (g *GormDeleteQuery) Exec(ctx context.Context) (common.Result, error) { func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
}
}()
result := g.db.WithContext(ctx).Delete(g.model) result := g.db.WithContext(ctx).Delete(g.model)
if result.Error != nil {
// Log SQL string for debugging
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Delete(g.model)
})
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
}
return &GormResult{result: result}, result.Error return &GormResult{result: result}, result.Error
} }

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,176 @@
package database
import (
"context"
"database/sql"
"fmt"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example demonstrates how to use the PgSQL adapter
func ExamplePgSQLAdapter() error {
// Connect to PostgreSQL database
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return fmt.Errorf("failed to open database: %w", err)
}
defer db.Close()
// Create the PgSQL adapter
adapter := NewPgSQLAdapter(db)
// Enable query debugging (optional)
adapter.EnableQueryDebug()
ctx := context.Background()
// Example 1: Simple SELECT query
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Where("age > ?", 18).
Order("created_at DESC").
Limit(10).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("select failed: %w", err)
}
// Example 2: INSERT query
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Returning("id").
Exec(ctx)
if err != nil {
return fmt.Errorf("insert failed: %w", err)
}
fmt.Printf("Rows affected: %d\n", result.RowsAffected())
// Example 3: UPDATE query
result, err = adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
if err != nil {
return fmt.Errorf("update failed: %w", err)
}
fmt.Printf("Rows updated: %d\n", result.RowsAffected())
// Example 4: DELETE query
result, err = adapter.NewDelete().
Table("users").
Where("age < ?", 18).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete failed: %w", err)
}
fmt.Printf("Rows deleted: %d\n", result.RowsAffected())
// Example 5: Using transactions
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
// Insert a new user
_, err := tx.NewInsert().
Table("users").
Value("name", "Transaction User").
Value("email", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Update another user
_, err = tx.NewUpdate().
Table("users").
Set("verified", true).
Where("email = ?", "tx@example.com").
Exec(ctx)
if err != nil {
return err
}
// Both operations succeed or both rollback
return nil
})
if err != nil {
return fmt.Errorf("transaction failed: %w", err)
}
// Example 6: JOIN query
err = adapter.NewSelect().
Table("users u").
Column("u.id", "u.name", "p.title as post_title").
LeftJoin("posts p ON p.user_id = u.id").
Where("u.active = ?", true).
Scan(ctx, &results)
if err != nil {
return fmt.Errorf("join query failed: %w", err)
}
// Example 7: Aggregation query
count, err := adapter.NewSelect().
Table("users").
Where("active = ?", true).
Count(ctx)
if err != nil {
return fmt.Errorf("count failed: %w", err)
}
fmt.Printf("Active users: %d\n", count)
// Example 8: Raw SQL execution
_, err = adapter.Exec(ctx, "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
if err != nil {
return fmt.Errorf("raw exec failed: %w", err)
}
// Example 9: Raw SQL query
var users []map[string]interface{}
err = adapter.Query(ctx, &users, "SELECT * FROM users WHERE age > $1 LIMIT $2", 18, 10)
if err != nil {
return fmt.Errorf("raw query failed: %w", err)
}
return nil
}
// User is an example model
type User struct {
ID int `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Age int `json:"age"`
}
// TableName implements common.TableNameProvider
func (u User) TableName() string {
return "users"
}
// ExampleWithModel demonstrates using models with the PgSQL adapter
func ExampleWithModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Use model with adapter
user := User{}
err = adapter.NewSelect().
Model(&user).
Where("id = ?", 1).
Scan(ctx, &user)
return err
}

View File

@ -0,0 +1,526 @@
// +build integration
package database
import (
"context"
"database/sql"
"fmt"
"testing"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
// Integration test models
type IntegrationUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
Posts []*IntegrationPost `bun:"rel:has-many,join:id=user_id"`
}
func (u IntegrationUser) TableName() string {
return "users"
}
type IntegrationPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
Published bool `db:"published"`
CreatedAt time.Time `db:"created_at"`
User *IntegrationUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []*IntegrationComment `bun:"rel:has-many,join:id=post_id"`
}
func (p IntegrationPost) TableName() string {
return "posts"
}
type IntegrationComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
CreatedAt time.Time `db:"created_at"`
Post *IntegrationPost `bun:"rel:belongs-to,join:post_id=id"`
}
func (c IntegrationComment) TableName() string {
return "comments"
}
// setupTestDB creates a PostgreSQL container and returns the connection
func setupTestDB(t *testing.T) (*sql.DB, func()) {
ctx := context.Background()
req := testcontainers.ContainerRequest{
Image: "postgres:15-alpine",
ExposedPorts: []string{"5432/tcp"},
Env: map[string]string{
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpass",
"POSTGRES_DB": "testdb",
},
WaitingFor: wait.ForLog("database system is ready to accept connections").
WithOccurrence(2).
WithStartupTimeout(60 * time.Second),
}
postgres, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
require.NoError(t, err)
host, err := postgres.Host(ctx)
require.NoError(t, err)
port, err := postgres.MappedPort(ctx, "5432")
require.NoError(t, err)
dsn := fmt.Sprintf("postgres://testuser:testpass@%s:%s/testdb?sslmode=disable",
host, port.Port())
db, err := sql.Open("pgx", dsn)
require.NoError(t, err)
// Wait for database to be ready
err = db.Ping()
require.NoError(t, err)
// Create schema
createSchema(t, db)
cleanup := func() {
db.Close()
postgres.Terminate(ctx)
}
return db, cleanup
}
// createSchema creates test tables
func createSchema(t *testing.T, db *sql.DB) {
schema := `
DROP TABLE IF EXISTS comments CASCADE;
DROP TABLE IF EXISTS posts CASCADE;
DROP TABLE IF EXISTS users CASCADE;
CREATE TABLE users (
id SERIAL PRIMARY KEY,
name VARCHAR(255) NOT NULL,
email VARCHAR(255) UNIQUE NOT NULL,
age INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE posts (
id SERIAL PRIMARY KEY,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
user_id INT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
published BOOLEAN DEFAULT false,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE comments (
id SERIAL PRIMARY KEY,
content TEXT NOT NULL,
post_id INT NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
`
_, err := db.Exec(schema)
require.NoError(t, err)
}
// TestIntegration_BasicCRUD tests basic CRUD operations
func TestIntegration_BasicCRUD(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// CREATE
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// READ
var users []IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, 25, users[0].Age)
userID := users[0].ID
// UPDATE
result, err = adapter.NewUpdate().
Table("users").
Set("age", 26).
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify update
var updatedUser IntegrationUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Scan(ctx, &updatedUser)
require.NoError(t, err)
assert.Equal(t, 26, updatedUser.Age)
// DELETE
result, err = adapter.NewDelete().
Table("users").
Where("id = ?", userID).
Exec(ctx)
require.NoError(t, err)
assert.Equal(t, int64(1), result.RowsAffected())
// Verify delete
count, err := adapter.NewSelect().
Table("users").
Where("id = ?", userID).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 0, count)
}
// TestIntegration_ScanModel tests ScanModel functionality
func TestIntegration_ScanModel(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Insert test data
_, err := adapter.NewInsert().
Table("users").
Value("name", "Jane Smith").
Value("email", "jane@example.com").
Value("age", 30).
Exec(ctx)
require.NoError(t, err)
// Test single struct scan
user := &IntegrationUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("email = ?", "jane@example.com").
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, "Jane Smith", user.Name)
assert.Equal(t, 30, user.Age)
// Test slice scan
users := []*IntegrationUser{}
err = adapter.NewSelect().
Model(&users).
Table("users").
ScanModel(ctx)
require.NoError(t, err)
assert.Len(t, users, 1)
}
// TestIntegration_Transaction tests transaction handling
func TestIntegration_Transaction(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Successful transaction
err := adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Alice").
Value("email", "alice@example.com").
Value("age", 28).
Exec(ctx)
if err != nil {
return err
}
_, err = tx.NewInsert().
Table("users").
Value("name", "Bob").
Value("email", "bob@example.com").
Value("age", 32).
Exec(ctx)
return err
})
require.NoError(t, err)
// Verify both records exist
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Failed transaction (should rollback)
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "Charlie").
Value("email", "charlie@example.com").
Value("age", 35).
Exec(ctx)
if err != nil {
return err
}
// Intentional error - duplicate email
_, err = tx.NewInsert().
Table("users").
Value("name", "David").
Value("email", "alice@example.com"). // Duplicate
Value("age", 40).
Exec(ctx)
return err
})
assert.Error(t, err)
// Verify rollback - count should still be 2
count, err = adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
}
// TestIntegration_Preload tests basic preload functionality
func TestIntegration_Preload(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "John Doe", "john@example.com", 25)
createTestPost(t, adapter, ctx, userID, "First Post", "Content 1", true)
createTestPost(t, adapter, ctx, userID, "Second Post", "Content 2", false)
// Test Preload
var users []*IntegrationUser
err := adapter.NewSelect().
Model(&IntegrationUser{}).
Table("users").
Preload("Posts").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0].Posts)
assert.Len(t, users[0].Posts, 2)
}
// TestIntegration_PreloadRelation tests smart PreloadRelation
func TestIntegration_PreloadRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Jane Smith", "jane@example.com", 30)
postID := createTestPost(t, adapter, ctx, userID, "Test Post", "Test Content", true)
createTestComment(t, adapter, ctx, postID, "Great post!")
createTestComment(t, adapter, ctx, postID, "Thanks for sharing!")
// Test PreloadRelation with belongs-to (should use JOIN)
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
// Note: JOIN preloading needs proper column selection to work
// For now, we test that it doesn't error
// Test PreloadRelation with has-many (should use subquery)
posts = []*IntegrationPost{}
err = adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
PreloadRelation("Comments").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
if posts[0].Comments != nil {
assert.Len(t, posts[0].Comments, 2)
}
}
// TestIntegration_JoinRelation tests explicit JoinRelation
func TestIntegration_JoinRelation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID := createTestUser(t, adapter, ctx, "Bob Wilson", "bob@example.com", 35)
createTestPost(t, adapter, ctx, userID, "Join Test", "Content", true)
// Test JoinRelation
var posts []*IntegrationPost
err := adapter.NewSelect().
Model(&IntegrationPost{}).
Table("posts").
JoinRelation("User").
Scan(ctx, &posts)
require.NoError(t, err)
assert.Len(t, posts, 1)
}
// TestIntegration_ComplexQuery tests complex queries
func TestIntegration_ComplexQuery(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
userID1 := createTestUser(t, adapter, ctx, "Alice", "alice@example.com", 25)
userID2 := createTestUser(t, adapter, ctx, "Bob", "bob@example.com", 30)
userID3 := createTestUser(t, adapter, ctx, "Charlie", "charlie@example.com", 35)
createTestPost(t, adapter, ctx, userID1, "Post 1", "Content", true)
createTestPost(t, adapter, ctx, userID2, "Post 2", "Content", true)
createTestPost(t, adapter, ctx, userID3, "Post 3", "Content", false)
// Complex query with joins, where, order, limit
var results []map[string]interface{}
err := adapter.NewSelect().
Table("posts p").
Column("p.title", "u.name as author_name", "u.age as author_age").
LeftJoin("users u ON u.id = p.user_id").
Where("p.published = ?", true).
WhereOr("u.age > ?", 25).
Order("u.age DESC").
Limit(2).
Scan(ctx, &results)
require.NoError(t, err)
assert.LessOrEqual(t, len(results), 2)
}
// TestIntegration_Aggregation tests aggregation queries
func TestIntegration_Aggregation(t *testing.T) {
db, cleanup := setupTestDB(t)
defer cleanup()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Create test data
createTestUser(t, adapter, ctx, "User 1", "user1@example.com", 20)
createTestUser(t, adapter, ctx, "User 2", "user2@example.com", 25)
createTestUser(t, adapter, ctx, "User 3", "user3@example.com", 30)
// Test Count
count, err := adapter.NewSelect().
Table("users").
Where("age >= ?", 25).
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 2, count)
// Test Exists
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "user1@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
// Test Group By with aggregation
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Column("age", "COUNT(*) as count").
Group("age").
Having("COUNT(*) > ?", 0).
Order("age ASC").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 3)
}
// Helper functions
func createTestUser(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, name, email string, age int) int {
var userID int
err := adapter.Query(ctx, &userID,
"INSERT INTO users (name, email, age) VALUES ($1, $2, $3) RETURNING id",
name, email, age)
require.NoError(t, err)
return userID
}
func createTestPost(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, userID int, title, content string, published bool) int {
var postID int
err := adapter.Query(ctx, &postID,
"INSERT INTO posts (title, content, user_id, published) VALUES ($1, $2, $3, $4) RETURNING id",
title, content, userID, published)
require.NoError(t, err)
return postID
}
func createTestComment(t *testing.T, adapter *PgSQLAdapter, ctx context.Context, postID int, content string) int {
var commentID int
err := adapter.Query(ctx, &commentID,
"INSERT INTO comments (content, post_id) VALUES ($1, $2) RETURNING id",
content, postID)
require.NoError(t, err)
return commentID
}

View File

@ -0,0 +1,275 @@
package database
import (
"context"
"database/sql"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Example models for demonstrating preload functionality
// Author model - has many Posts
type Author struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Posts []*Post `bun:"rel:has-many,join:id=author_id"`
}
func (a Author) TableName() string {
return "authors"
}
// Post model - belongs to Author, has many Comments
type Post struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
AuthorID int `db:"author_id"`
Author *Author `bun:"rel:belongs-to,join:author_id=id"`
Comments []*Comment `bun:"rel:has-many,join:id=post_id"`
}
func (p Post) TableName() string {
return "posts"
}
// Comment model - belongs to Post
type Comment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
Post *Post `bun:"rel:belongs-to,join:post_id=id"`
}
func (c Comment) TableName() string {
return "comments"
}
// ExamplePreload demonstrates the Preload functionality
func ExamplePreload() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Simple Preload (uses subquery for has-many)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
Preload("Posts"). // Load all posts for each author
Scan(ctx, &authors)
if err != nil {
return err
}
// Now authors[i].Posts will be populated with their posts
return nil
}
// ExamplePreloadRelation demonstrates smart PreloadRelation with auto-detection
func ExamplePreloadRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: PreloadRelation auto-detects has-many (uses subquery)
var authors []*Author
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Where("published = ?", true).Order("created_at DESC")
}).
Where("active = ?", true).
Scan(ctx, &authors)
if err != nil {
return err
}
// Example 2: PreloadRelation auto-detects belongs-to (uses JOIN)
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
PreloadRelation("Author"). // Will use JOIN because it's belongs-to
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 3: Nested preloads
err = adapter.NewSelect().
Model(&Author{}).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
// First load posts, then preload comments for each post
return q.Limit(10)
}).
Scan(ctx, &authors)
if err != nil {
return err
}
// Manually load nested relationships (two-level preloading)
for _, author := range authors {
if author.Posts != nil {
for _, post := range author.Posts {
var comments []*Comment
err := adapter.NewSelect().
Table("comments").
Where("post_id = ?", post.ID).
Scan(ctx, &comments)
if err == nil {
post.Comments = comments
}
}
}
}
return nil
}
// ExampleJoinRelation demonstrates explicit JOIN loading
func ExampleJoinRelation() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Force JOIN for belongs-to relationship
var posts []*Post
err = adapter.NewSelect().
Model(&Post{}).
Table("posts").
JoinRelation("Author", func(q common.SelectQuery) common.SelectQuery {
return q.Where("active = ?", true)
}).
Scan(ctx, &posts)
if err != nil {
return err
}
// Example 2: Multiple JOINs
err = adapter.NewSelect().
Model(&Post{}).
Table("posts p").
Column("p.*", "a.name as author_name", "a.email as author_email").
LeftJoin("authors a ON a.id = p.author_id").
Where("p.published = ?", true).
Scan(ctx, &posts)
return err
}
// ExampleScanModel demonstrates ScanModel with struct destinations
func ExampleScanModel() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
// Example 1: Scan single struct
author := Author{}
err = adapter.NewSelect().
Model(&author).
Table("authors").
Where("id = ?", 1).
ScanModel(ctx) // ScanModel automatically uses the model set with Model()
if err != nil {
return err
}
// Example 2: Scan slice of structs
authors := []*Author{}
err = adapter.NewSelect().
Model(&authors).
Table("authors").
Where("active = ?", true).
Limit(10).
ScanModel(ctx)
return err
}
// ExampleCompleteWorkflow demonstrates a complete workflow with preloading
func ExampleCompleteWorkflow() error {
dsn := "postgres://username:password@localhost:5432/dbname?sslmode=disable"
db, err := sql.Open("pgx", dsn)
if err != nil {
return err
}
defer db.Close()
adapter := NewPgSQLAdapter(db)
adapter.EnableQueryDebug() // Enable query logging
ctx := context.Background()
// Step 1: Create an author
author := &Author{
Name: "John Doe",
Email: "john@example.com",
}
result, err := adapter.NewInsert().
Table("authors").
Value("name", author.Name).
Value("email", author.Email).
Returning("id").
Exec(ctx)
if err != nil {
return err
}
_ = result
// Step 2: Load author with all their posts
var loadedAuthor Author
err = adapter.NewSelect().
Model(&loadedAuthor).
Table("authors").
PreloadRelation("Posts", func(q common.SelectQuery) common.SelectQuery {
return q.Order("created_at DESC").Limit(5)
}).
Where("id = ?", 1).
ScanModel(ctx)
if err != nil {
return err
}
// Step 3: Update author name
_, err = adapter.NewUpdate().
Table("authors").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
return err
}

View File

@ -0,0 +1,629 @@
package database
import (
"context"
"database/sql"
"reflect"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
// Test models
type TestUser struct {
ID int `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
}
func (u TestUser) TableName() string {
return "users"
}
type TestPost struct {
ID int `db:"id"`
Title string `db:"title"`
Content string `db:"content"`
UserID int `db:"user_id"`
User *TestUser `bun:"rel:belongs-to,join:user_id=id"`
Comments []TestComment `bun:"rel:has-many,join:id=post_id"`
}
func (p TestPost) TableName() string {
return "posts"
}
type TestComment struct {
ID int `db:"id"`
Content string `db:"content"`
PostID int `db:"post_id"`
}
func (c TestComment) TableName() string {
return "comments"
}
// TestNewPgSQLAdapter tests adapter creation
func TestNewPgSQLAdapter(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
assert.NotNil(t, adapter)
assert.Equal(t, db, adapter.db)
}
// TestPgSQLSelectQuery_BuildSQL tests SQL query building
func TestPgSQLSelectQuery_BuildSQL(t *testing.T) {
tests := []struct {
name string
setup func(*PgSQLSelectQuery)
expected string
}{
{
name: "simple select",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
},
expected: "SELECT * FROM users",
},
{
name: "select with columns",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.columns = []string{"id", "name", "email"}
},
expected: "SELECT id, name, email FROM users",
},
{
name: "select with where",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.whereClauses = []string{"age > $1"}
q.args = []interface{}{18}
},
expected: "SELECT * FROM users WHERE (age > $1)",
},
{
name: "select with order and limit",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.orderBy = []string{"created_at DESC"}
q.limit = 10
q.offset = 5
},
expected: "SELECT * FROM users ORDER BY created_at DESC LIMIT 10 OFFSET 5",
},
{
name: "select with join",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.joins = []string{"LEFT JOIN posts ON posts.user_id = users.id"}
},
expected: "SELECT * FROM users LEFT JOIN posts ON posts.user_id = users.id",
},
{
name: "select with group and having",
setup: func(q *PgSQLSelectQuery) {
q.tableName = "users"
q.groupBy = []string{"country"}
q.havingClauses = []string{"COUNT(*) > $1"}
q.args = []interface{}{5}
},
expected: "SELECT * FROM users GROUP BY country HAVING COUNT(*) > $1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{
columns: []string{"*"},
}
tt.setup(q)
sql := q.buildSQL()
assert.Equal(t, tt.expected, sql)
})
}
}
// TestPgSQLSelectQuery_ReplacePlaceholders tests placeholder replacement
func TestPgSQLSelectQuery_ReplacePlaceholders(t *testing.T) {
tests := []struct {
name string
query string
argCount int
paramCounter int
expected string
}{
{
name: "single placeholder",
query: "age > ?",
argCount: 1,
paramCounter: 0,
expected: "age > $1",
},
{
name: "multiple placeholders",
query: "age > ? AND status = ?",
argCount: 2,
paramCounter: 0,
expected: "age > $1 AND status = $2",
},
{
name: "with existing counter",
query: "name = ?",
argCount: 1,
paramCounter: 5,
expected: "name = $6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := &PgSQLSelectQuery{paramCounter: tt.paramCounter}
result := q.replacePlaceholders(tt.query, tt.argCount)
assert.Equal(t, tt.expected, result)
})
}
}
// TestPgSQLSelectQuery_Chaining tests method chaining
func TestPgSQLSelectQuery_Chaining(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
query := adapter.NewSelect().
Table("users").
Column("id", "name").
Where("age > ?", 18).
Order("name ASC").
Limit(10).
Offset(5)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, []string{"id", "name"}, pgQuery.columns)
assert.Len(t, pgQuery.whereClauses, 1)
assert.Equal(t, []string{"name ASC"}, pgQuery.orderBy)
assert.Equal(t, 10, pgQuery.limit)
assert.Equal(t, 5, pgQuery.offset)
}
// TestPgSQLSelectQuery_Model tests model setting
func TestPgSQLSelectQuery_Model(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
user := &TestUser{}
query := adapter.NewSelect().Model(user)
pgQuery := query.(*PgSQLSelectQuery)
assert.Equal(t, "users", pgQuery.tableName)
assert.Equal(t, user, pgQuery.model)
}
// TestScanRowsToStructSlice tests scanning rows into struct slice
func TestScanRowsToStructSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25).
AddRow(2, "Jane Smith", "jane@example.com", 30)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "John Doe", users[0].Name)
assert.Equal(t, "jane@example.com", users[1].Email)
assert.Equal(t, 30, users[1].Age)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToStructSlicePointers tests scanning rows into pointer slice
func TestScanRowsToStructSlicePointers(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var users []*TestUser
err = adapter.NewSelect().
Table("users").
Scan(ctx, &users)
require.NoError(t, err)
assert.Len(t, users, 1)
assert.NotNil(t, users[0])
assert.Equal(t, "John Doe", users[0].Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToSingleStruct tests scanning a single row
func TestScanRowsToSingleStruct(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var user TestUser
err = adapter.NewSelect().
Table("users").
Where("id = ?", 1).
Scan(ctx, &user)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestScanRowsToMapSlice tests scanning into map slice
func TestScanRowsToMapSlice(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email"}).
AddRow(1, "John Doe", "john@example.com").
AddRow(2, "Jane Smith", "jane@example.com")
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
var results []map[string]interface{}
err = adapter.NewSelect().
Table("users").
Scan(ctx, &results)
require.NoError(t, err)
assert.Len(t, results, 2)
assert.Equal(t, int64(1), results[0]["id"])
assert.Equal(t, "John Doe", results[0]["name"])
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLInsertQuery_Exec tests insert query execution
func TestPgSQLInsertQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("INSERT INTO users").
WithArgs("John Doe", "john@example.com", 25).
WillReturnResult(sqlmock.NewResult(1, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewInsert().
Table("users").
Value("name", "John Doe").
Value("email", "john@example.com").
Value("age", 25).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLUpdateQuery_Exec tests update query execution
func TestPgSQLUpdateQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Note: Args order is SET values first, then WHERE values
mock.ExpectExec("UPDATE users SET name = \\$1 WHERE id = \\$2").
WithArgs("Jane Doe", 1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewUpdate().
Table("users").
Set("name", "Jane Doe").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLDeleteQuery_Exec tests delete query execution
func TestPgSQLDeleteQuery_Exec(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectExec("DELETE FROM users WHERE id = \\$1").
WithArgs(1).
WillReturnResult(sqlmock.NewResult(0, 1))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
result, err := adapter.NewDelete().
Table("users").
Where("id = ?", 1).
Exec(ctx)
require.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, int64(1), result.RowsAffected())
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Count tests count query
func TestPgSQLSelectQuery_Count(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(42)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
count, err := adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(t, err)
assert.Equal(t, 42, count)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLSelectQuery_Exists tests exists query
func TestPgSQLSelectQuery_Exists(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"count"}).AddRow(1)
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
exists, err := adapter.NewSelect().
Table("users").
Where("email = ?", "john@example.com").
Exists(ctx)
require.NoError(t, err)
assert.True(t, exists)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_Transaction tests transaction handling
func TestPgSQLAdapter_Transaction(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
require.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestPgSQLAdapter_TransactionRollback tests transaction rollback
func TestPgSQLAdapter_TransactionRollback(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
mock.ExpectBegin()
mock.ExpectExec("INSERT INTO users").WillReturnError(sql.ErrConnDone)
mock.ExpectRollback()
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
err = adapter.RunInTransaction(ctx, func(tx common.Database) error {
_, err := tx.NewInsert().
Table("users").
Value("name", "John").
Exec(ctx)
return err
})
assert.Error(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestBuildFieldMap tests field mapping construction
func TestBuildFieldMap(t *testing.T) {
userType := reflect.TypeOf(TestUser{})
fieldMap := buildFieldMap(userType, nil)
assert.NotEmpty(t, fieldMap)
// Check that fields are mapped
assert.Contains(t, fieldMap, "id")
assert.Contains(t, fieldMap, "name")
assert.Contains(t, fieldMap, "email")
assert.Contains(t, fieldMap, "age")
// Check field info
idInfo := fieldMap["id"]
assert.Equal(t, "ID", idInfo.Name)
}
// TestGetRelationMetadata tests relationship metadata extraction
func TestGetRelationMetadata(t *testing.T) {
q := &PgSQLSelectQuery{
model: &TestPost{},
}
// Test belongs-to relationship
meta := q.getRelationMetadata("User")
assert.NotNil(t, meta)
assert.Equal(t, "User", meta.fieldName)
// Test has-many relationship
meta = q.getRelationMetadata("Comments")
assert.NotNil(t, meta)
assert.Equal(t, "Comments", meta.fieldName)
}
// TestPreloadConfiguration tests preload configuration
func TestPreloadConfiguration(t *testing.T) {
db, _, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
adapter := NewPgSQLAdapter(db)
// Test Preload
query := adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
Preload("User")
pgQuery := query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.False(t, pgQuery.preloads[0].useJoin)
// Test PreloadRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
PreloadRelation("Comments")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "Comments", pgQuery.preloads[0].relation)
// Test JoinRelation
query = adapter.NewSelect().
Model(&TestPost{}).
Table("posts").
JoinRelation("User")
pgQuery = query.(*PgSQLSelectQuery)
assert.Len(t, pgQuery.preloads, 1)
assert.Equal(t, "User", pgQuery.preloads[0].relation)
assert.True(t, pgQuery.preloads[0].useJoin)
}
// TestScanModel tests ScanModel functionality
func TestScanModel(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
rows := sqlmock.NewRows([]string{"id", "name", "email", "age"}).
AddRow(1, "John Doe", "john@example.com", 25)
mock.ExpectQuery("SELECT (.+) FROM users").WillReturnRows(rows)
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
user := &TestUser{}
err = adapter.NewSelect().
Model(user).
Table("users").
Where("id = ?", 1).
ScanModel(ctx)
require.NoError(t, err)
assert.Equal(t, 1, user.ID)
assert.Equal(t, "John Doe", user.Name)
assert.NoError(t, mock.ExpectationsWereMet())
}
// TestRawSQL tests raw SQL execution
func TestRawSQL(t *testing.T) {
db, mock, err := sqlmock.New()
require.NoError(t, err)
defer db.Close()
// Test Exec
mock.ExpectExec("CREATE TABLE test").WillReturnResult(sqlmock.NewResult(0, 0))
adapter := NewPgSQLAdapter(db)
ctx := context.Background()
_, err = adapter.Exec(ctx, "CREATE TABLE test (id INT)")
require.NoError(t, err)
// Test Query
rows := sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "Test")
mock.ExpectQuery("SELECT (.+) FROM test").WillReturnRows(rows)
var results []map[string]interface{}
err = adapter.Query(ctx, &results, "SELECT * FROM test WHERE id = $1", 1)
require.NoError(t, err)
assert.Len(t, results, 1)
assert.NoError(t, mock.ExpectationsWereMet())
}

View File

@ -0,0 +1,132 @@
package database
import (
"context"
"database/sql"
"testing"
"github.com/stretchr/testify/require"
)
// TestHelper provides utilities for database testing
type TestHelper struct {
DB *sql.DB
Adapter *PgSQLAdapter
t *testing.T
}
// NewTestHelper creates a new test helper
func NewTestHelper(t *testing.T, db *sql.DB) *TestHelper {
return &TestHelper{
DB: db,
Adapter: NewPgSQLAdapter(db),
t: t,
}
}
// CleanupTables truncates all test tables
func (h *TestHelper) CleanupTables() {
ctx := context.Background()
tables := []string{"comments", "posts", "users"}
for _, table := range tables {
_, err := h.DB.ExecContext(ctx, "TRUNCATE TABLE "+table+" CASCADE")
require.NoError(h.t, err)
}
}
// InsertUser inserts a test user and returns the ID
func (h *TestHelper) InsertUser(name, email string, age int) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("users").
Value("name", name).
Value("email", email).
Value("age", age).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertPost inserts a test post and returns the ID
func (h *TestHelper) InsertPost(userID int, title, content string, published bool) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("posts").
Value("user_id", userID).
Value("title", title).
Value("content", content).
Value("published", published).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// InsertComment inserts a test comment and returns the ID
func (h *TestHelper) InsertComment(postID int, content string) int {
ctx := context.Background()
result, err := h.Adapter.NewInsert().
Table("comments").
Value("post_id", postID).
Value("content", content).
Exec(ctx)
require.NoError(h.t, err)
id, _ := result.LastInsertId()
return int(id)
}
// AssertUserExists checks if a user exists by email
func (h *TestHelper) AssertUserExists(email string) {
ctx := context.Background()
exists, err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Exists(ctx)
require.NoError(h.t, err)
require.True(h.t, exists, "User with email %s should exist", email)
}
// AssertUserCount asserts the number of users
func (h *TestHelper) AssertUserCount(expected int) {
ctx := context.Background()
count, err := h.Adapter.NewSelect().
Table("users").
Count(ctx)
require.NoError(h.t, err)
require.Equal(h.t, expected, count)
}
// GetUserByEmail retrieves a user by email
func (h *TestHelper) GetUserByEmail(email string) map[string]interface{} {
ctx := context.Background()
var results []map[string]interface{}
err := h.Adapter.NewSelect().
Table("users").
Where("email = ?", email).
Scan(ctx, &results)
require.NoError(h.t, err)
require.Len(h.t, results, 1, "Expected exactly one user with email %s", email)
return results[0]
}
// BeginTestTransaction starts a transaction for testing
func (h *TestHelper) BeginTestTransaction() (*PgSQLTxAdapter, func()) {
ctx := context.Background()
tx, err := h.DB.BeginTx(ctx, nil)
require.NoError(h.t, err)
adapter := &PgSQLTxAdapter{tx: tx}
cleanup := func() {
tx.Rollback()
}
return adapter, cleanup
}

View File

@ -0,0 +1,161 @@
package database
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Test models for bun
type BunTestModel struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"email"`
ComputedCol string `bun:"computed_col,scanonly"`
}
// Test models for gorm
type GormTestModel struct {
ID int `gorm:"column:id;primaryKey"`
Name string `gorm:"column:name"`
Email string `gorm:"column:email"`
ReadOnlyCol string `gorm:"column:readonly_col;->"`
NoWriteCol string `gorm:"column:nowrite_col;<-:false"`
}
func TestIsColumnWritable_Bun(t *testing.T) {
model := &BunTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "scanonly column should not be writable",
columnName: "computed_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestIsColumnWritable_Gorm(t *testing.T) {
model := &GormTestModel{}
tests := []struct {
name string
columnName string
expected bool
}{
{
name: "writable column - id",
columnName: "id",
expected: true,
},
{
name: "writable column - name",
columnName: "name",
expected: true,
},
{
name: "writable column - email",
columnName: "email",
expected: true,
},
{
name: "read-only column with -> should not be writable",
columnName: "readonly_col",
expected: false,
},
{
name: "column with <-:false should not be writable",
columnName: "nowrite_col",
expected: false,
},
{
name: "non-existent column should be writable (dynamic)",
columnName: "nonexistent",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := reflection.IsColumnWritable(model, tt.columnName)
if result != tt.expected {
t.Errorf("IsColumnWritable(%q) = %v, want %v", tt.columnName, result, tt.expected)
}
})
}
}
func TestBunUpdateQuery_SetMap_FiltersScanOnly(t *testing.T) {
// Note: This is a unit test for the validation logic only.
// We can't fully test the bun query without a database connection,
// but we've verified the validation logic in TestIsColumnWritable_Bun
t.Skip("Skipping integration test - validation logic tested in TestIsColumnWritable_Bun")
}
func TestGormUpdateQuery_SetMap_FiltersReadOnly(t *testing.T) {
model := &GormTestModel{}
query := &GormUpdateQuery{
model: model,
}
// SetMap should filter out read-only columns
values := map[string]interface{}{
"name": "John",
"email": "john@example.com",
"readonly_col": "should_be_filtered",
"nowrite_col": "should_also_be_filtered",
}
query.SetMap(values)
// Check that the updates map only contains writable columns
if updates, ok := query.updates.(map[string]interface{}); ok {
if _, exists := updates["readonly_col"]; exists {
t.Error("readonly_col should have been filtered out")
}
if _, exists := updates["nowrite_col"]; exists {
t.Error("nowrite_col should have been filtered out")
}
if _, exists := updates["name"]; !exists {
t.Error("name should be in updates")
}
if _, exists := updates["email"]; !exists {
t.Error("email should be in updates")
}
} else {
t.Error("updates should be a map[string]interface{}")
}
}

View File

@ -0,0 +1,16 @@
package database
import (
"strings"
)
// parseTableName splits a table name that may contain schema into separate schema and table
// For example: "public.users" -> ("public", "users")
//
// "users" -> ("", "users")
func parseTableName(fullTableName string) (schema, table string) {
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
return fullTableName[:idx], fullTableName[idx+1:]
}
return "", fullTableName
}

View File

@ -3,8 +3,10 @@ package router
import ( import (
"net/http" "net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/uptrace/bunrouter" "github.com/uptrace/bunrouter"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface // BunRouterAdapter adapts uptrace/bunrouter to work with our Router interface
@ -34,7 +36,11 @@ func (b *BunRouterAdapter) HandleFunc(pattern string, handler common.HTTPHandler
func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { func (b *BunRouterAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface // This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router // For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetBunRouter() for direct access") w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetBunRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
} }
// GetBunRouter returns the underlying bunrouter for direct access // GetBunRouter returns the underlying bunrouter for direct access
@ -120,6 +126,16 @@ func (b *BunRouterRequest) QueryParam(key string) string {
return b.req.URL.Query().Get(key) return b.req.URL.Query().Get(key)
} }
func (b *BunRouterRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range b.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (b *BunRouterRequest) AllHeaders() map[string]string { func (b *BunRouterRequest) AllHeaders() map[string]string {
headers := make(map[string]string) headers := make(map[string]string)
for key, values := range b.req.Header { for key, values := range b.req.Header {
@ -130,6 +146,12 @@ func (b *BunRouterRequest) AllHeaders() map[string]string {
return headers return headers
} }
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (b *BunRouterRequest) UnderlyingRequest() *http.Request {
return b.req.Request
}
// StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers // StandardBunRouterAdapter creates routes compatible with standard bunrouter handlers
type StandardBunRouterAdapter struct { type StandardBunRouterAdapter struct {
*BunRouterAdapter *BunRouterAdapter
@ -190,4 +212,3 @@ func DefaultBunRouterConfig() *BunRouterConfig {
HandleOPTIONS: true, HandleOPTIONS: true,
} }
} }

View File

@ -5,8 +5,10 @@ import (
"io" "io"
"net/http" "net/http"
"github.com/Warky-Devs/ResolveSpec/pkg/common"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
) )
// MuxAdapter adapts Gorilla Mux to work with our Router interface // MuxAdapter adapts Gorilla Mux to work with our Router interface
@ -31,7 +33,11 @@ func (m *MuxAdapter) HandleFunc(pattern string, handler common.HTTPHandlerFunc)
func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) { func (m *MuxAdapter) ServeHTTP(w common.ResponseWriter, r common.Request) {
// This method would be used when we need to serve through our interface // This method would be used when we need to serve through our interface
// For now, we'll work directly with the underlying router // For now, we'll work directly with the underlying router
panic("ServeHTTP not implemented - use GetMuxRouter() for direct access") w.WriteHeader(http.StatusNotImplemented)
_, err := w.Write([]byte(`{"error":"ServeHTTP not implemented - use GetMuxRouter() for direct access"}`))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
} }
// MuxRouteRegistration implements RouteRegistration for Mux // MuxRouteRegistration implements RouteRegistration for Mux
@ -116,6 +122,16 @@ func (h *HTTPRequest) QueryParam(key string) string {
return h.req.URL.Query().Get(key) return h.req.URL.Query().Get(key)
} }
func (h *HTTPRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range h.req.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (h *HTTPRequest) AllHeaders() map[string]string { func (h *HTTPRequest) AllHeaders() map[string]string {
headers := make(map[string]string) headers := make(map[string]string)
for key, values := range h.req.Header { for key, values := range h.req.Header {
@ -126,10 +142,16 @@ func (h *HTTPRequest) AllHeaders() map[string]string {
return headers return headers
} }
// UnderlyingRequest returns the underlying *http.Request
// This is useful when you need to pass the request to other handlers
func (h *HTTPRequest) UnderlyingRequest() *http.Request {
return h.req
}
// HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter // HTTPResponseWriter adapts our ResponseWriter interface to standard http.ResponseWriter
type HTTPResponseWriter struct { type HTTPResponseWriter struct {
resp http.ResponseWriter resp http.ResponseWriter
w common.ResponseWriter w common.ResponseWriter //nolint:unused
status int status int
} }
@ -155,6 +177,12 @@ func (h *HTTPResponseWriter) WriteJSON(data interface{}) error {
return json.NewEncoder(h.resp).Encode(data) return json.NewEncoder(h.resp).Encode(data)
} }
// UnderlyingResponseWriter returns the underlying http.ResponseWriter
// This is useful when you need to pass the response writer to other handlers
func (h *HTTPResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return h.resp
}
// StandardMuxAdapter creates routes compatible with standard http.HandlerFunc // StandardMuxAdapter creates routes compatible with standard http.HandlerFunc
type StandardMuxAdapter struct { type StandardMuxAdapter struct {
*MuxAdapter *MuxAdapter

119
pkg/common/cors.go Normal file
View File

@ -0,0 +1,119 @@
package common
import (
"fmt"
"strings"
)
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string
AllowedMethods []string
AllowedHeaders []string
MaxAge int
}
// DefaultCORSConfig returns a default CORS configuration suitable for HeadSpec
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"},
AllowedHeaders: GetHeadSpecHeaders(),
MaxAge: 86400, // 24 hours
}
}
// GetHeadSpecHeaders returns all headers used by HeadSpec
func GetHeadSpecHeaders() []string {
return []string{
// Standard headers
"Content-Type",
"Authorization",
"Accept",
"Accept-Language",
"Content-Language",
// Field Selection
"X-Select-Fields",
"X-Not-Select-Fields",
"X-Clean-JSON",
// Filtering & Search
"X-FieldFilter-*",
"X-SearchFilter-*",
"X-SearchOp-*",
"X-SearchOr-*",
"X-SearchAnd-*",
"X-SearchCols",
"X-Custom-SQL-W",
"X-Custom-SQL-W-*",
"X-Custom-SQL-Or",
"X-Custom-SQL-Or-*",
// Joins & Relations
"X-Preload",
"X-Preload-*",
"X-Expand",
"X-Expand-*",
"X-Custom-SQL-Join",
"X-Custom-SQL-Join-*",
// Sorting & Pagination
"X-Sort",
"X-Sort-*",
"X-Limit",
"X-Offset",
"X-Cursor-Forward",
"X-Cursor-Backward",
// Advanced Features
"X-AdvSQL-*",
"X-CQL-Sel-*",
"X-Distinct",
"X-SkipCount",
"X-SkipCache",
"X-Fetch-RowNumber",
"X-PKRow",
// Response Format
"X-SimpleAPI",
"X-DetailAPI",
"X-Syncfusion",
"X-Single-Record-As-Object",
// Transaction Control
"X-Transaction-Atomic",
// X-Files - comprehensive JSON configuration
"X-Files",
}
}
// SetCORSHeaders sets CORS headers on a response writer
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
// Set allowed origins
if len(config.AllowedOrigins) > 0 {
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
}
// Set allowed methods
if len(config.AllowedMethods) > 0 {
w.SetHeader("Access-Control-Allow-Methods", strings.Join(config.AllowedMethods, ", "))
}
// Set allowed headers
if len(config.AllowedHeaders) > 0 {
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
}
// Set max age
if config.MaxAge > 0 {
w.SetHeader("Access-Control-Max-Age", fmt.Sprintf("%d", config.MaxAge))
}
// Allow credentials
w.SetHeader("Access-Control-Allow-Credentials", "true")
// Expose headers that clients can read
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
}

View File

@ -0,0 +1,97 @@
package common
// Example showing how to use the common handler interfaces
// This file demonstrates the handler interface hierarchy and usage patterns
// ProcessWithAnyHandler demonstrates using the base SpecHandler interface
// which works with any handler type (resolvespec, restheadspec, or funcspec)
func ProcessWithAnyHandler(handler SpecHandler) Database {
// All handlers expose GetDatabase() through the SpecHandler interface
return handler.GetDatabase()
}
// ProcessCRUDRequest demonstrates using the CRUDHandler interface
// which works with resolvespec.Handler and restheadspec.Handler
func ProcessCRUDRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement Handle()
handler.Handle(w, r, params)
}
// ProcessMetadataRequest demonstrates getting metadata from CRUD handlers
func ProcessMetadataRequest(handler CRUDHandler, w ResponseWriter, r Request, params map[string]string) {
// Both resolvespec and restheadspec handlers implement HandleGet()
handler.HandleGet(w, r, params)
}
// Example usage patterns (not executable, just for documentation):
/*
// Example 1: Using with resolvespec.Handler
func ExampleResolveSpec() {
db := // ... get database
registry := // ... get registry
handler := resolvespec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 2: Using with restheadspec.Handler
func ExampleRestHeadSpec() {
db := // ... get database
registry := // ... get registry
handler := restheadspec.NewHandler(db, registry)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as CRUDHandler
var crudHandler CRUDHandler = handler
crudHandler.Handle(w, r, params)
crudHandler.HandleGet(w, r, params)
}
// Example 3: Using with funcspec.Handler
func ExampleFuncSpec() {
db := // ... get database
handler := funcspec.NewHandler(db)
// Can be used as SpecHandler
var specHandler SpecHandler = handler
database := specHandler.GetDatabase()
// Can be used as QueryHandler
var queryHandler QueryHandler = handler
// funcspec has different methods: SqlQueryList() and SqlQuery()
// which return HTTP handler functions
}
// Example 4: Polymorphic handler processing
func ProcessHandlers(handlers []SpecHandler) {
for _, handler := range handlers {
// All handlers expose the database
db := handler.GetDatabase()
// Type switch for specific handler types
switch h := handler.(type) {
case CRUDHandler:
// This is resolvespec or restheadspec
// Can call Handle() and HandleGet()
_ = h
case QueryHandler:
// This is funcspec
// Can call SqlQueryList() and SqlQuery()
_ = h
}
}
}
*/

View File

@ -0,0 +1,47 @@
package common
import (
"fmt"
"reflect"
)
// ValidateAndUnwrapModelResult contains the result of model validation
type ValidateAndUnwrapModelResult struct {
ModelType reflect.Type
Model interface{}
ModelPtr interface{}
OriginalType reflect.Type
}
// ValidateAndUnwrapModel validates that a model is a struct type and unwraps
// pointers, slices, and arrays to get to the base struct type.
// Returns an error if the model is not a valid struct type.
func ValidateAndUnwrapModel(model interface{}) (*ValidateAndUnwrapModelResult, error) {
modelType := reflect.TypeOf(model)
originalType := modelType
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, fmt.Errorf("model must be a struct type, got %v. Ensure you register the struct (e.g., ModelCoreAccount{}) not a slice (e.g., []*ModelCoreAccount)", originalType)
}
// If the registered model was a pointer or slice, use the unwrapped struct type
if originalType != modelType {
model = reflect.New(modelType).Elem().Interface()
}
// Create a pointer to the model type for database operations
modelPtr := reflect.New(reflect.TypeOf(model)).Interface()
return &ValidateAndUnwrapModelResult{
ModelType: modelType,
Model: model,
ModelPtr: modelPtr,
OriginalType: originalType,
}, nil
}

View File

@ -1,6 +1,11 @@
package common package common
import "context" import (
"context"
"encoding/json"
"io"
"net/http"
)
// Database interface designed to work with both GORM and Bun // Database interface designed to work with both GORM and Bun
type Database interface { type Database interface {
@ -19,6 +24,12 @@ type Database interface {
CommitTx(ctx context.Context) error CommitTx(ctx context.Context) error
RollbackTx(ctx context.Context) error RollbackTx(ctx context.Context) error
RunInTransaction(ctx context.Context, fn func(Database) error) error RunInTransaction(ctx context.Context, fn func(Database) error) error
// GetUnderlyingDB returns the underlying database connection
// For GORM, this returns *gorm.DB
// For Bun, this returns *bun.DB
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
GetUnderlyingDB() interface{}
} }
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun) // SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
@ -26,11 +37,14 @@ type SelectQuery interface {
Model(model interface{}) SelectQuery Model(model interface{}) SelectQuery
Table(table string) SelectQuery Table(table string) SelectQuery
Column(columns ...string) SelectQuery Column(columns ...string) SelectQuery
ColumnExpr(query string, args ...interface{}) SelectQuery
Where(query string, args ...interface{}) SelectQuery Where(query string, args ...interface{}) SelectQuery
WhereOr(query string, args ...interface{}) SelectQuery WhereOr(query string, args ...interface{}) SelectQuery
Join(query string, args ...interface{}) SelectQuery Join(query string, args ...interface{}) SelectQuery
LeftJoin(query string, args ...interface{}) SelectQuery LeftJoin(query string, args ...interface{}) SelectQuery
Preload(relation string, conditions ...interface{}) SelectQuery Preload(relation string, conditions ...interface{}) SelectQuery
PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery
Order(order string) SelectQuery Order(order string) SelectQuery
Limit(n int) SelectQuery Limit(n int) SelectQuery
Offset(n int) SelectQuery Offset(n int) SelectQuery
@ -39,6 +53,7 @@ type SelectQuery interface {
// Execution methods // Execution methods
Scan(ctx context.Context, dest interface{}) error Scan(ctx context.Context, dest interface{}) error
ScanModel(ctx context.Context) error
Count(ctx context.Context) (int, error) Count(ctx context.Context) (int, error)
Exists(ctx context.Context) (bool, error) Exists(ctx context.Context) (bool, error)
} }
@ -113,6 +128,8 @@ type Request interface {
Body() ([]byte, error) Body() ([]byte, error)
PathParam(key string) string PathParam(key string) string
QueryParam(key string) string QueryParam(key string) string
AllQueryParams() map[string]string // Get all query parameters as a map
UnderlyingRequest() *http.Request // Get the underlying *http.Request for forwarding to other handlers
} }
// ResponseWriter interface abstracts HTTP response // ResponseWriter interface abstracts HTTP response
@ -121,17 +138,164 @@ type ResponseWriter interface {
WriteHeader(statusCode int) WriteHeader(statusCode int)
Write(data []byte) (int, error) Write(data []byte) (int, error)
WriteJSON(data interface{}) error WriteJSON(data interface{}) error
UnderlyingResponseWriter() http.ResponseWriter // Get the underlying http.ResponseWriter for forwarding to other handlers
} }
// HTTPHandlerFunc type for HTTP handlers // HTTPHandlerFunc type for HTTP handlers
type HTTPHandlerFunc func(ResponseWriter, Request) type HTTPHandlerFunc func(ResponseWriter, Request)
// WrapHTTPRequest wraps standard http.ResponseWriter and *http.Request into common interfaces
func WrapHTTPRequest(w http.ResponseWriter, r *http.Request) (ResponseWriter, Request) {
return &StandardResponseWriter{w: w}, &StandardRequest{r: r}
}
// StandardResponseWriter adapts http.ResponseWriter to ResponseWriter interface
type StandardResponseWriter struct {
w http.ResponseWriter
status int
}
func (s *StandardResponseWriter) SetHeader(key, value string) {
s.w.Header().Set(key, value)
}
func (s *StandardResponseWriter) WriteHeader(statusCode int) {
s.status = statusCode
s.w.WriteHeader(statusCode)
}
func (s *StandardResponseWriter) Write(data []byte) (int, error) {
return s.w.Write(data)
}
func (s *StandardResponseWriter) WriteJSON(data interface{}) error {
s.SetHeader("Content-Type", "application/json")
return json.NewEncoder(s.w).Encode(data)
}
func (s *StandardResponseWriter) UnderlyingResponseWriter() http.ResponseWriter {
return s.w
}
// StandardRequest adapts *http.Request to Request interface
type StandardRequest struct {
r *http.Request
body []byte
}
func (s *StandardRequest) Method() string {
return s.r.Method
}
func (s *StandardRequest) URL() string {
return s.r.URL.String()
}
func (s *StandardRequest) Header(key string) string {
return s.r.Header.Get(key)
}
func (s *StandardRequest) AllHeaders() map[string]string {
headers := make(map[string]string)
for key, values := range s.r.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
return headers
}
func (s *StandardRequest) Body() ([]byte, error) {
if s.body != nil {
return s.body, nil
}
if s.r.Body == nil {
return nil, nil
}
defer s.r.Body.Close()
body, err := io.ReadAll(s.r.Body)
if err != nil {
return nil, err
}
s.body = body
return body, nil
}
func (s *StandardRequest) PathParam(key string) string {
// Standard http.Request doesn't have path params
// This should be set by the router
return ""
}
func (s *StandardRequest) QueryParam(key string) string {
return s.r.URL.Query().Get(key)
}
func (s *StandardRequest) AllQueryParams() map[string]string {
params := make(map[string]string)
for key, values := range s.r.URL.Query() {
if len(values) > 0 {
params[key] = values[0]
}
}
return params
}
func (s *StandardRequest) UnderlyingRequest() *http.Request {
return s.r
}
// TableNameProvider interface for models that provide table names // TableNameProvider interface for models that provide table names
type TableNameProvider interface { type TableNameProvider interface {
TableName() string TableName() string
} }
type TableAliasProvider interface {
TableAlias() string
}
// PrimaryKeyNameProvider interface for models that provide primary key column names
type PrimaryKeyNameProvider interface {
GetIDName() string
}
// SchemaProvider interface for models that provide schema names // SchemaProvider interface for models that provide schema names
type SchemaProvider interface { type SchemaProvider interface {
SchemaName() string SchemaName() string
} }
// SpecHandler interface represents common functionality across all spec handlers
// This is the base interface implemented by:
// - resolvespec.Handler: Handles CRUD operations via request body with explicit operation field
// - restheadspec.Handler: Handles CRUD operations via HTTP methods (GET/POST/PUT/DELETE)
// - funcspec.Handler: Handles custom SQL query execution with dynamic parameters
//
// The interface hierarchy is:
//
// SpecHandler (base)
// ├── CRUDHandler (resolvespec, restheadspec)
// └── QueryHandler (funcspec)
type SpecHandler interface {
// GetDatabase returns the underlying database connection
GetDatabase() Database
}
// CRUDHandler interface for handlers that support CRUD operations
// This is implemented by resolvespec.Handler and restheadspec.Handler
type CRUDHandler interface {
SpecHandler
// Handle processes API requests through router-agnostic interface
Handle(w ResponseWriter, r Request, params map[string]string)
// HandleGet processes GET requests for metadata
HandleGet(w ResponseWriter, r Request, params map[string]string)
}
// QueryHandler interface for handlers that execute SQL queries
// This is implemented by funcspec.Handler
// Note: funcspec uses standard http.ResponseWriter and *http.Request instead of common interfaces
type QueryHandler interface {
SpecHandler
// Methods are defined in funcspec package due to different function signature requirements
}

View File

@ -0,0 +1,453 @@
package common
import (
"context"
"fmt"
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// CRUDRequestProvider interface for models that provide CRUD request strings
type CRUDRequestProvider interface {
GetRequest() string
}
// RelationshipInfoProvider interface for handlers that can provide relationship info
type RelationshipInfoProvider interface {
GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo
}
// RelationshipInfo contains information about a model relationship
type RelationshipInfo struct {
FieldName string
JSONName string
RelationType string // "belongsTo", "hasMany", "hasOne", "many2many"
ForeignKey string
References string
JoinTable string
RelatedModel interface{}
}
// NestedCUDProcessor handles recursive processing of nested object graphs
type NestedCUDProcessor struct {
db Database
registry ModelRegistry
relationshipHelper RelationshipInfoProvider
}
// NewNestedCUDProcessor creates a new nested CUD processor
func NewNestedCUDProcessor(db Database, registry ModelRegistry, relationshipHelper RelationshipInfoProvider) *NestedCUDProcessor {
return &NestedCUDProcessor{
db: db,
registry: registry,
relationshipHelper: relationshipHelper,
}
}
// ProcessResult contains the result of processing a CUD operation
type ProcessResult struct {
ID interface{} // The ID of the processed record
AffectedRows int64 // Number of rows affected
Data map[string]interface{} // The processed data
RelationData map[string]interface{} // Data from processed relations
}
// ProcessNestedCUD recursively processes nested object graphs for Create, Update, Delete operations
// with automatic foreign key resolution
func (p *NestedCUDProcessor) ProcessNestedCUD(
ctx context.Context,
operation string, // "insert", "update", or "delete"
data map[string]interface{},
model interface{},
parentIDs map[string]interface{}, // Parent IDs for foreign key resolution
tableName string,
) (*ProcessResult, error) {
logger.Info("Processing nested CUD: operation=%s, table=%s", operation, tableName)
result := &ProcessResult{
Data: make(map[string]interface{}),
RelationData: make(map[string]interface{}),
}
// Check if data has a _request field that overrides the operation
if requestOp := p.extractCRUDRequest(data); requestOp != "" {
logger.Debug("Found _request override: %s", requestOp)
operation = requestOp
}
// Get model type for reflection
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
}
// Separate relation fields from regular fields
relationFields := make(map[string]*RelationshipInfo)
regularData := make(map[string]interface{})
for key, value := range data {
// Skip _request field in actual data processing
if key == "_request" {
continue
}
// Check if this field is a relation
relInfo := p.relationshipHelper.GetRelationshipInfo(modelType, key)
if relInfo != nil {
relationFields[key] = relInfo
result.RelationData[key] = value
} else {
regularData[key] = value
}
}
// Inject parent IDs for foreign key resolution
p.injectForeignKeys(regularData, modelType, parentIDs)
// Get the primary key name for this model
pkName := reflection.GetPrimaryKeyName(model)
// Process based on operation
switch strings.ToLower(operation) {
case "insert", "create":
id, err := p.processInsert(ctx, regularData, tableName)
if err != nil {
return nil, fmt.Errorf("insert failed: %w", err)
}
result.ID = id
result.AffectedRows = 1
result.Data = regularData
// Process child relations after parent insert (to get parent ID)
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
case "update":
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
if err != nil {
return nil, fmt.Errorf("update failed: %w", err)
}
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
// Process child relations for update
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations: %w", err)
}
case "delete":
// Process child relations first (for referential integrity)
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
}
rows, err := p.processDelete(ctx, tableName, data[pkName])
if err != nil {
return nil, fmt.Errorf("delete failed: %w", err)
}
result.ID = data[pkName]
result.AffectedRows = rows
result.Data = regularData
default:
return nil, fmt.Errorf("unsupported operation: %s", operation)
}
logger.Info("Nested CUD completed: operation=%s, id=%v, rows=%d", operation, result.ID, result.AffectedRows)
return result, nil
}
// extractCRUDRequest extracts the request field from data if present
func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) string {
if request, ok := data["_request"]; ok {
if requestStr, ok := request.(string); ok {
return strings.ToLower(strings.TrimSpace(requestStr))
}
}
return ""
}
// injectForeignKeys injects parent IDs into data for foreign key fields
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
if len(parentIDs) == 0 {
return
}
// Iterate through model fields to find foreign key fields
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
jsonTag := field.Tag.Get("json")
jsonName := strings.Split(jsonTag, ",")[0]
// Check if this field is a foreign key and we have a parent ID for it
// Common patterns: DepartmentID, ManagerID, ProjectID, etc.
for parentKey, parentID := range parentIDs {
// Match field name patterns like "department_id" with parent key "department"
if strings.EqualFold(jsonName, parentKey+"_id") ||
strings.EqualFold(jsonName, parentKey+"id") ||
strings.EqualFold(field.Name, parentKey+"ID") {
// Only inject if not already present
if _, exists := data[jsonName]; !exists {
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID)
data[jsonName] = parentID
}
}
}
}
}
// processInsert handles insert operation
func (p *NestedCUDProcessor) processInsert(
ctx context.Context,
data map[string]interface{},
tableName string,
) (interface{}, error) {
logger.Debug("Inserting into %s with data: %+v", tableName, data)
query := p.db.NewInsert().Table(tableName)
for key, value := range data {
query = query.Value(key, value)
}
// Add RETURNING clause to get the inserted ID
query = query.Returning("id")
result, err := query.Exec(ctx)
if err != nil {
return nil, fmt.Errorf("insert exec failed: %w", err)
}
// Try to get the ID
var id interface{}
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
id = lastID
} else if data["id"] != nil {
id = data["id"]
}
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
return id, nil
}
// processUpdate handles update operation
func (p *NestedCUDProcessor) processUpdate(
ctx context.Context,
data map[string]interface{},
tableName string,
id interface{},
) (int64, error) {
if id == nil {
return 0, fmt.Errorf("update requires an ID")
}
logger.Debug("Updating %s with ID %v, data: %+v", tableName, id, data)
query := p.db.NewUpdate().Table(tableName).SetMap(data).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
result, err := query.Exec(ctx)
if err != nil {
return 0, fmt.Errorf("update exec failed: %w", err)
}
rows := result.RowsAffected()
logger.Debug("Update successful, rows affected: %d", rows)
return rows, nil
}
// processDelete handles delete operation
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
if id == nil {
return 0, fmt.Errorf("delete requires an ID")
}
logger.Debug("Deleting from %s with ID %v", tableName, id)
query := p.db.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", QuoteIdent(reflection.GetPrimaryKeyName(tableName))), id)
result, err := query.Exec(ctx)
if err != nil {
return 0, fmt.Errorf("delete exec failed: %w", err)
}
rows := result.RowsAffected()
logger.Debug("Delete successful, rows affected: %d", rows)
return rows, nil
}
// processChildRelations recursively processes child relations
func (p *NestedCUDProcessor) processChildRelations(
ctx context.Context,
operation string,
parentID interface{},
relationFields map[string]*RelationshipInfo,
relationData map[string]interface{},
parentModelType reflect.Type,
) error {
for relationName, relInfo := range relationFields {
relationValue, exists := relationData[relationName]
if !exists || relationValue == nil {
continue
}
logger.Debug("Processing relation: %s, type: %s", relationName, relInfo.RelationType)
// Get the related model
field, found := parentModelType.FieldByName(relInfo.FieldName)
if !found {
logger.Warn("Field %s not found in model", relInfo.FieldName)
continue
}
// Get the model type for the relation
relatedModelType := field.Type
if relatedModelType.Kind() == reflect.Slice {
relatedModelType = relatedModelType.Elem()
}
if relatedModelType.Kind() == reflect.Ptr {
relatedModelType = relatedModelType.Elem()
}
// Create an instance of the related model
relatedModel := reflect.New(relatedModelType).Elem().Interface()
// Get table name for related model
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
// Prepare parent IDs for foreign key injection
parentIDs := make(map[string]interface{})
if relInfo.ForeignKey != "" {
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
parentIDs[baseName] = parentID
}
// Process based on relation type and data structure
switch v := relationValue.(type) {
case map[string]interface{}:
// Single related object
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
}
case []interface{}:
// Multiple related objects
for i, item := range v {
if itemMap, ok := item.(map[string]interface{}); ok {
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
}
}
}
case []map[string]interface{}:
// Multiple related objects (typed slice)
for i, itemMap := range v {
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
if err != nil {
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
}
}
default:
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
}
}
return nil
}
// getTableNameForModel gets the table name for a model
func (p *NestedCUDProcessor) getTableNameForModel(model interface{}, defaultName string) string {
if provider, ok := model.(TableNameProvider); ok {
tableName := provider.TableName()
if tableName != "" {
return tableName
}
}
return defaultName
}
// ShouldUseNestedProcessor determines if we should use nested CUD processing
// It recursively checks if the data contains:
// 1. A _request field at any level, OR
// 2. Nested relations that themselves contain further nested relations or _request fields
// This ensures nested processing is only used when there are deeply nested operations
func ShouldUseNestedProcessor(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider) bool {
return shouldUseNestedProcessorDepth(data, model, relationshipHelper, 0)
}
// shouldUseNestedProcessorDepth is the internal recursive implementation with depth tracking
func shouldUseNestedProcessorDepth(data map[string]interface{}, model interface{}, relationshipHelper RelationshipInfoProvider, depth int) bool {
// Check for _request field
if _, hasCRUDRequest := data["_request"]; hasCRUDRequest {
return true
}
// Get model type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return false
}
// Check if data contains any fields that are relations (nested objects or arrays)
for key, value := range data {
// Skip _request and regular scalar fields
if key == "_request" {
continue
}
// Check if this field is a relation in the model
relInfo := relationshipHelper.GetRelationshipInfo(modelType, key)
if relInfo != nil {
// Check if the value is actually nested data (object or array)
switch v := value.(type) {
case map[string]interface{}, []interface{}, []map[string]interface{}:
// If we're already at a nested level (depth > 0) and found a relation,
// that means we have multi-level nesting, so return true
if depth > 0 {
return true
}
// At depth 0, recurse to check if the nested data has further nesting
switch typedValue := v.(type) {
case map[string]interface{}:
if shouldUseNestedProcessorDepth(typedValue, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
case []interface{}:
for _, item := range typedValue {
if itemMap, ok := item.(map[string]interface{}); ok {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
case []map[string]interface{}:
for _, itemMap := range typedValue {
if shouldUseNestedProcessorDepth(itemMap, relInfo.RelatedModel, relationshipHelper, depth+1) {
return true
}
}
}
}
}
}
return false
}

635
pkg/common/sql_helpers.go Normal file
View File

@ -0,0 +1,635 @@
package common
import (
"fmt"
"regexp"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ValidateAndFixPreloadWhere validates and normalizes WHERE clauses for preloads
//
// NOTE: For preload queries, table aliases from the parent query are not valid since
// the preload executes as a separate query with its own table alias. This function
// now simply validates basic syntax without requiring or adding prefixes.
// The actual alias normalization happens in the database adapter layer.
//
// Returns the WHERE clause and an error if it contains obviously invalid syntax.
func ValidateAndFixPreloadWhere(where string, relationName string) (string, error) {
if where == "" {
return where, nil
}
where = strings.TrimSpace(where)
// Just do basic validation - don't require or add prefixes
// The database adapter will handle alias normalization
// Check if the WHERE clause contains any qualified column references
// If it does, log a debug message but don't fail - let the adapter handle it
if strings.Contains(where, ".") {
logger.Debug("Preload WHERE clause for '%s' contains qualified column references: '%s'. "+
"Note: In preload context, table aliases from parent query are not available. "+
"The database adapter will normalize aliases automatically.", relationName, where)
}
// Validate that it's not empty or just whitespace
if where == "" {
return where, nil
}
// Return the WHERE clause as-is
// The BunSelectQuery.Where() method will handle alias normalization via normalizeTableAlias()
return where, nil
}
// IsSQLExpression checks if a condition is a SQL expression that shouldn't be prefixed
func IsSQLExpression(cond string) bool {
// Common SQL literals and expressions
sqlLiterals := []string{"true", "false", "null", "1=1", "1 = 1", "0=0", "0 = 0"}
for _, literal := range sqlLiterals {
if cond == literal {
return true
}
}
return false
}
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
// These conditions should be removed from WHERE clauses as they have no filtering effect
func IsTrivialCondition(cond string) bool {
cond = strings.TrimSpace(cond)
lowerCond := strings.ToLower(cond)
// Conditions that always evaluate to true
trivialConditions := []string{
"1=1", "1 = 1", "1= 1", "1 =1",
"true", "true = true", "true=true", "true= true", "true =true",
"0=0", "0 = 0", "0= 0", "0 =0",
}
for _, trivial := range trivialConditions {
if lowerCond == trivial {
return true
}
}
return false
}
// validateWhereClauseSecurity checks for dangerous SQL statements in WHERE clauses
// Returns an error if any dangerous keywords are found
func validateWhereClauseSecurity(where string) error {
if where == "" {
return nil
}
lowerWhere := strings.ToLower(where)
// List of dangerous SQL keywords that should never appear in WHERE clauses
dangerousKeywords := []string{
"delete ", "delete\t", "delete\n", "delete;",
"update ", "update\t", "update\n", "update;",
"truncate ", "truncate\t", "truncate\n", "truncate;",
"drop ", "drop\t", "drop\n", "drop;",
"alter ", "alter\t", "alter\n", "alter;",
"create ", "create\t", "create\n", "create;",
"insert ", "insert\t", "insert\n", "insert;",
"grant ", "grant\t", "grant\n", "grant;",
"revoke ", "revoke\t", "revoke\n", "revoke;",
"exec ", "exec\t", "exec\n", "exec;",
"execute ", "execute\t", "execute\n", "execute;",
";delete", ";update", ";truncate", ";drop", ";alter", ";create", ";insert",
}
for _, keyword := range dangerousKeywords {
if strings.Contains(lowerWhere, keyword) {
logger.Error("Dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
return fmt.Errorf("dangerous SQL keyword detected in WHERE clause: %s", strings.TrimSpace(keyword))
}
}
return nil
}
// SanitizeWhereClause removes trivial conditions and fixes incorrect table prefixes
// This function should be used everywhere a WHERE statement is sent to ensure clean, efficient SQL
//
// Parameters:
// - where: The WHERE clause string to sanitize
// - tableName: The correct table/relation name to use when fixing incorrect prefixes
// - options: Optional RequestOptions containing preload relations that should be allowed as valid prefixes
//
// Returns:
// - The sanitized WHERE clause with trivial conditions removed and incorrect prefixes fixed
// - An empty string if all conditions were trivial or the input was empty
//
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
// prefix matches a preloaded relation name, in which case it's left unchanged.
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
if where == "" {
return ""
}
where = strings.TrimSpace(where)
// Validate that the WHERE clause doesn't contain dangerous SQL statements
if err := validateWhereClauseSecurity(where); err != nil {
logger.Debug("Security validation failed for WHERE clause: %v", err)
return ""
}
// Strip outer parentheses and re-trim
where = stripOuterParentheses(where)
// Get valid columns from the model if tableName is provided
var validColumns map[string]bool
if tableName != "" {
validColumns = getValidColumnsForTable(tableName)
}
// Build a set of allowed table prefixes (main table + preloaded relations)
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
}
// Split by AND to handle multiple conditions
conditions := splitByAND(where)
validConditions := make([]string, 0, len(conditions))
for _, cond := range conditions {
cond = strings.TrimSpace(cond)
if cond == "" {
continue
}
// Strip parentheses from the condition before checking
condToCheck := stripOuterParentheses(cond)
// Skip trivial conditions that always evaluate to true
if IsTrivialCondition(condToCheck) {
logger.Debug("Removing trivial condition: '%s'", cond)
continue
}
// If tableName is provided and the condition HAS a table prefix, check if it's correct
if tableName != "" && hasTablePrefix(condToCheck) {
// Extract the current prefix and column name
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name
oldRef := currentPrefix + "." + columnName
newRef := tableName + "." + columnName
cond = strings.Replace(cond, oldRef, newRef, 1)
logger.Debug("Fixed incorrect table prefix in condition: '%s' -> '%s'", oldRef, newRef)
} else {
logger.Debug("Skipping prefix fix for '%s.%s' - not a valid column in main table (might be preload relation)", currentPrefix, columnName)
}
}
}
} else if tableName != "" && !hasTablePrefix(condToCheck) {
// If tableName is provided and the condition DOESN'T have a table prefix,
// qualify unambiguous column references to prevent "ambiguous column" errors
// when there are multiple joins on the same table (e.g., recursive preloads)
columnName := extractUnqualifiedColumnName(condToCheck)
if columnName != "" && (validColumns == nil || isValidColumn(columnName, validColumns)) {
// Qualify the column with the table name
// Be careful to only replace the column name, not other occurrences of the string
oldRef := columnName
newRef := tableName + "." + columnName
// Use word boundary matching to avoid replacing partial matches
cond = qualifyColumnInCondition(cond, oldRef, newRef)
logger.Debug("Qualified unqualified column in condition: '%s' added table prefix '%s'", oldRef, tableName)
}
}
validConditions = append(validConditions, cond)
}
if len(validConditions) == 0 {
return ""
}
result := strings.Join(validConditions, " AND ")
if result != where {
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
}
return result
}
// stripOuterParentheses removes matching outer parentheses from a string
// It handles nested parentheses correctly
func stripOuterParentheses(s string) string {
s = strings.TrimSpace(s)
for {
if len(s) < 2 || s[0] != '(' || s[len(s)-1] != ')' {
return s
}
// Check if these parentheses match (i.e., they're the outermost pair)
depth := 0
matched := false
for i := 0; i < len(s); i++ {
switch s[i] {
case '(':
depth++
case ')':
depth--
if depth == 0 && i == len(s)-1 {
matched = true
} else if depth == 0 {
// Found a closing paren before the end, so outer parens don't match
return s
}
}
}
if !matched {
return s
}
// Strip the outer parentheses and continue
s = strings.TrimSpace(s[1 : len(s)-1])
}
}
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
// This is parenthesis-aware and won't split on AND operators inside subqueries
func splitByAND(where string) []string {
conditions := []string{}
currentCondition := strings.Builder{}
depth := 0 // Track parenthesis depth
i := 0
for i < len(where) {
ch := where[i]
// Track parenthesis depth
if ch == '(' {
depth++
currentCondition.WriteByte(ch)
i++
continue
} else if ch == ')' {
depth--
currentCondition.WriteByte(ch)
i++
continue
}
// Only look for AND operators at depth 0 (not inside parentheses)
if depth == 0 {
// Check if we're at an AND operator (case-insensitive)
// We need at least " AND " (5 chars) or " and " (5 chars)
if i+5 <= len(where) {
substring := where[i : i+5]
lowerSubstring := strings.ToLower(substring)
if lowerSubstring == " and " {
// Found an AND operator at the top level
// Add the current condition to the list
conditions = append(conditions, currentCondition.String())
currentCondition.Reset()
// Skip past the AND operator
i += 5
continue
}
}
}
// Not an AND operator or we're inside parentheses, just add the character
currentCondition.WriteByte(ch)
i++
}
// Add the last condition
if currentCondition.Len() > 0 {
conditions = append(conditions, currentCondition.String())
}
return conditions
}
// hasTablePrefix checks if a condition already has a table/relation prefix (contains a dot)
func hasTablePrefix(cond string) bool {
// Look for patterns like "table.column" or "`table`.`column`" or "\"table\".\"column\""
return strings.Contains(cond, ".")
}
// ExtractColumnName extracts the column name from a WHERE condition
// For example: "status = 'active'" returns "status"
func ExtractColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
for _, op := range operators {
if idx := strings.Index(cond, op); idx > 0 {
columnName := strings.TrimSpace(cond[:idx])
// Remove quotes if present
columnName = strings.Trim(columnName, "`\"'")
return columnName
}
}
// If no operator found, check if it's a simple identifier (for boolean columns)
parts := strings.Fields(cond)
if len(parts) > 0 {
columnName := strings.Trim(parts[0], "`\"'")
// Check if it's a valid identifier (not a SQL keyword)
if !IsSQLKeyword(strings.ToLower(columnName)) {
return columnName
}
}
return ""
}
// IsSQLKeyword checks if a string is a SQL keyword that shouldn't be treated as a column name
func IsSQLKeyword(word string) bool {
keywords := []string{"select", "from", "where", "and", "or", "not", "in", "is", "null", "true", "false", "like", "between", "exists"}
for _, kw := range keywords {
if word == kw {
return true
}
}
return false
}
// getValidColumnsForTable retrieves the valid SQL columns for a table from the model registry
// Returns a map of column names for fast lookup, or nil if the model is not found
func getValidColumnsForTable(tableName string) map[string]bool {
// Try to get the model from the registry
model, err := modelregistry.GetModelByName(tableName)
if err != nil {
// Model not found, return nil to indicate we should use fallback behavior
return nil
}
// Get SQL columns from the model
columns := reflection.GetSQLModelColumns(model)
if len(columns) == 0 {
// No columns found, return nil
return nil
}
// Build a map for fast lookup
columnMap := make(map[string]bool, len(columns))
for _, col := range columns {
columnMap[strings.ToLower(col)] = true
}
return columnMap
}
// extractTableAndColumn extracts the table prefix and column name from a qualified reference
// For example: "users.status = 'active'" returns ("users", "status")
// Returns empty strings if no table prefix is found
// This function is parenthesis-aware and will only look for operators outside of subqueries
func extractTableAndColumn(cond string) (table string, column string) {
// Common SQL operators to find the column reference
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is "}
var columnRef string
// Find the column reference (left side of the operator)
// We need to find the first operator that appears OUTSIDE of parentheses
minIdx := -1
for _, op := range operators {
idx := findOperatorOutsideParentheses(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
}
// If no operator found, the whole condition might be the column reference
if columnRef == "" {
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return "", ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Check if there's a function call (contains opening parenthesis)
openParenIdx := strings.Index(columnRef, "(")
if openParenIdx >= 0 {
// There's a function call - find the FIRST dot after the opening paren
// This handles cases like: ifblnk(users.status, orders.status) - extracts users.status
dotIdx := strings.Index(columnRef[openParenIdx:], ".")
if dotIdx > 0 {
dotIdx += openParenIdx // Adjust to absolute position
// Extract table name (between paren and dot)
// Find the last opening paren before this dot
lastOpenParen := strings.LastIndex(columnRef[:dotIdx], "(")
table = columnRef[lastOpenParen+1 : dotIdx]
// Find the column name - it ends at comma, closing paren, whitespace, or end of string
columnStart := dotIdx + 1
columnEnd := len(columnRef)
for i := columnStart; i < len(columnRef); i++ {
ch := columnRef[i]
if ch == ',' || ch == ')' || ch == ' ' || ch == '\t' {
columnEnd = i
break
}
}
column = columnRef[columnStart:columnEnd]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
}
// No function call - check if it contains a dot (qualified reference)
// Use LastIndex to handle schema.table.column properly
if dotIdx := strings.LastIndex(columnRef, "."); dotIdx > 0 {
table = columnRef[:dotIdx]
column = columnRef[dotIdx+1:]
// Remove quotes from table and column if present
table = strings.Trim(table, "`\"'")
column = strings.Trim(column, "`\"'")
return table, column
}
return "", ""
}
// extractUnqualifiedColumnName extracts the column name from an unqualified condition
// For example: "rid_parentmastertaskitem is null" returns "rid_parentmastertaskitem"
// "status = 'active'" returns "status"
func extractUnqualifiedColumnName(cond string) string {
// Common SQL operators
operators := []string{" = ", " != ", " <> ", " > ", " >= ", " < ", " <= ", " LIKE ", " like ", " IN ", " in ", " IS ", " is ", " NOT ", " not "}
// Find the column reference (left side of the operator)
minIdx := -1
for _, op := range operators {
idx := strings.Index(cond, op)
if idx > 0 && (minIdx == -1 || idx < minIdx) {
minIdx = idx
}
}
var columnRef string
if minIdx > 0 {
columnRef = strings.TrimSpace(cond[:minIdx])
} else {
// No operator found, might be a single column reference
parts := strings.Fields(cond)
if len(parts) > 0 {
columnRef = parts[0]
}
}
if columnRef == "" {
return ""
}
// Remove any quotes
columnRef = strings.Trim(columnRef, "`\"'")
// Return empty if it contains a dot (already qualified) or function call
if strings.Contains(columnRef, ".") || strings.Contains(columnRef, "(") {
return ""
}
return columnRef
}
// qualifyColumnInCondition replaces an unqualified column name with a qualified one in a condition
// Uses word boundaries to avoid partial matches
// For example: qualifyColumnInCondition("rid_item is null", "rid_item", "table.rid_item")
// returns "table.rid_item is null"
func qualifyColumnInCondition(cond, oldRef, newRef string) string {
// Use word boundary matching with Go's supported regex syntax
// \b matches word boundaries
escapedOld := regexp.QuoteMeta(oldRef)
pattern := `\b` + escapedOld + `\b`
re, err := regexp.Compile(pattern)
if err != nil {
// If regex fails, fall back to simple string replacement
logger.Debug("Failed to compile regex for column qualification, using simple replace: %v", err)
return strings.Replace(cond, oldRef, newRef, 1)
}
// Only replace if the match is not preceded by a dot (to avoid replacing already qualified columns)
result := cond
matches := re.FindAllStringIndex(cond, -1)
// Process matches in reverse order to maintain correct indices
for i := len(matches) - 1; i >= 0; i-- {
match := matches[i]
start := match[0]
// Check if preceded by a dot (already qualified)
if start > 0 && cond[start-1] == '.' {
continue
}
// Replace this occurrence
result = result[:start] + newRef + result[match[1]:]
}
return result
}
// findOperatorOutsideParentheses finds the first occurrence of an operator outside of parentheses
// Returns the index of the operator, or -1 if not found or only found inside parentheses
func findOperatorOutsideParentheses(s string, operator string) int {
depth := 0
inSingleQuote := false
inDoubleQuote := false
for i := 0; i < len(s); i++ {
ch := s[i]
// Track quote state (operators inside quotes should be ignored)
if ch == '\'' && !inDoubleQuote {
inSingleQuote = !inSingleQuote
continue
}
if ch == '"' && !inSingleQuote {
inDoubleQuote = !inDoubleQuote
continue
}
// Skip if we're inside quotes
if inSingleQuote || inDoubleQuote {
continue
}
// Track parenthesis depth
switch ch {
case '(':
depth++
case ')':
depth--
}
// Only look for the operator when we're outside parentheses (depth == 0)
if depth == 0 {
// Check if the operator starts at this position
if i+len(operator) <= len(s) {
if s[i:i+len(operator)] == operator {
return i
}
}
}
}
return -1
}
// isValidColumn checks if a column name exists in the valid columns map
// Handles case-insensitive comparison
func isValidColumn(columnName string, validColumns map[string]bool) bool {
if validColumns == nil {
return true // No model info, assume valid
}
return validColumns[strings.ToLower(columnName)]
}

View File

@ -0,0 +1,641 @@
package common
import (
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
func TestSanitizeWhereClause(t *testing.T) {
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "trivial conditions in parentheses",
where: "(true AND true AND true)",
tableName: "mastertask",
expected: "",
},
{
name: "trivial conditions without parentheses",
where: "true AND true AND true",
tableName: "mastertask",
expected: "",
},
{
name: "single trivial condition",
where: "true",
tableName: "mastertask",
expected: "",
},
{
name: "valid condition with parentheses - prefix added to prevent ambiguity",
where: "(status = 'active')",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "mixed trivial and valid conditions - prefix added",
where: "true AND status = 'active' AND 1=1",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition with correct table prefix - unchanged",
where: "users.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "condition with incorrect table prefix - fixed",
where: "wrong_table.status = 'active'",
tableName: "users",
expected: "users.status = 'active'",
},
{
name: "multiple conditions with incorrect prefix - fixed",
where: "wrong_table.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "multiple valid conditions without prefix - prefixes added",
where: "status = 'active' AND age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "no table name provided",
where: "status = 'active'",
tableName: "",
expected: "status = 'active'",
},
{
name: "empty where clause",
where: "",
tableName: "users",
expected: "",
},
{
name: "mixed correct and incorrect prefixes",
where: "users.status = 'active' AND wrong_table.age > 18",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18",
},
{
name: "mixed case AND operators",
where: "status = 'active' AND age > 18 and name = 'John'",
tableName: "users",
expected: "users.status = 'active' AND users.age > 18 AND users.name = 'John'",
},
{
name: "subquery with ORDER BY and LIMIT - allowed",
where: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
tableName: "users",
expected: "users.id IN (SELECT users.id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
},
{
name: "dangerous DELETE keyword - blocked",
where: "status = 'active'; DELETE FROM users",
tableName: "users",
expected: "",
},
{
name: "dangerous UPDATE keyword - blocked",
where: "1=1; UPDATE users SET admin = true",
tableName: "users",
expected: "",
},
{
name: "dangerous TRUNCATE keyword - blocked",
where: "status = 'active' OR TRUNCATE TABLE users",
tableName: "users",
expected: "",
},
{
name: "dangerous DROP keyword - blocked",
where: "status = 'active'; DROP TABLE users",
tableName: "users",
expected: "",
},
{
name: "subquery with table alias should not be modified",
where: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
tableName: "apiprovider",
expected: "apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576)",
},
{
name: "complex subquery with AND and multiple operators",
where: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
tableName: "apiprovider",
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
func TestStripOuterParentheses(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "single level parentheses",
input: "(true)",
expected: "true",
},
{
name: "multiple levels",
input: "((true))",
expected: "true",
},
{
name: "no parentheses",
input: "true",
expected: "true",
},
{
name: "mismatched parentheses",
input: "(true",
expected: "(true",
},
{
name: "complex expression",
input: "(a AND b)",
expected: "a AND b",
},
{
name: "nested but not outer",
input: "(a AND (b OR c)) AND d",
expected: "(a AND (b OR c)) AND d",
},
{
name: "with spaces",
input: " ( true ) ",
expected: "true",
},
{
name: "complex sub query",
input: "(a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3)",
expected: "a = 1 AND b = 2 or c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := stripOuterParentheses(tt.input)
if result != tt.expected {
t.Errorf("stripOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
}
})
}
}
func TestIsTrivialCondition(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{"true", "true", true},
{"true with spaces", " true ", true},
{"TRUE uppercase", "TRUE", true},
{"1=1", "1=1", true},
{"1 = 1", "1 = 1", true},
{"true = true", "true = true", true},
{"valid condition", "status = 'active'", false},
{"false", "false", false},
{"column name", "is_active", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsTrivialCondition(tt.input)
if result != tt.expected {
t.Errorf("IsTrivialCondition(%q) = %v; want %v", tt.input, result, tt.expected)
}
})
}
}
func TestExtractTableAndColumn(t *testing.T) {
tests := []struct {
name string
input string
expectedTable string
expectedCol string
}{
{
name: "qualified column with equals",
input: "users.status = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "qualified column with greater than",
input: "users.age > 18",
expectedTable: "users",
expectedCol: "age",
},
{
name: "qualified column with LIKE",
input: "users.name LIKE '%john%'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "qualified column with IN",
input: "users.status IN ('active', 'pending')",
expectedTable: "users",
expectedCol: "status",
},
{
name: "unqualified column",
input: "status = 'active'",
expectedTable: "",
expectedCol: "",
},
{
name: "qualified with backticks",
input: "`users`.`status` = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "schema.table.column reference",
input: "public.users.status = 'active'",
expectedTable: "public.users",
expectedCol: "status",
},
{
name: "empty string",
input: "",
expectedTable: "",
expectedCol: "",
},
{
name: "function call with table.column - ifblnk",
input: "ifblnk(users.status,0) in (1,2,3,4)",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function call with table.column - coalesce",
input: "coalesce(users.age, 0) = 25",
expectedTable: "users",
expectedCol: "age",
},
{
name: "nested function calls",
input: "upper(trim(users.name)) = 'JOHN'",
expectedTable: "users",
expectedCol: "name",
},
{
name: "function with multiple args and table.column",
input: "substring(users.email, 1, 5) = 'admin'",
expectedTable: "users",
expectedCol: "email",
},
{
name: "cast function with table.column",
input: "cast(orders.total as decimal) > 100",
expectedTable: "orders",
expectedCol: "total",
},
{
name: "complex nested functions",
input: "coalesce(nullif(users.status, ''), 'default') = 'active'",
expectedTable: "users",
expectedCol: "status",
},
{
name: "function with multiple table.column refs (extracts first)",
input: "greatest(users.created_at, users.updated_at) > '2024-01-01'",
expectedTable: "users",
expectedCol: "created_at",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
table, col := extractTableAndColumn(tt.input)
if table != tt.expectedTable || col != tt.expectedCol {
t.Errorf("extractTableAndColumn(%q) = (%q, %q); want (%q, %q)",
tt.input, table, col, tt.expectedTable, tt.expectedCol)
}
})
}
}
func TestSanitizeWhereClauseWithPreloads(t *testing.T) {
tests := []struct {
name string
where string
tableName string
options *RequestOptions
expected string
}{
{
name: "preload relation prefix is preserved",
where: "Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "Department.name = 'Engineering'",
},
{
name: "multiple preload relations - all preserved",
where: "Department.name = 'Engineering' AND Manager.status = 'active'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
{Relation: "Manager"},
},
},
expected: "Department.name = 'Engineering' AND Manager.status = 'active'",
},
{
name: "mix of main table and preload relation",
where: "users.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "incorrect prefix fixed when not a preload relation",
where: "wrong_table.status = 'active' AND Department.name = 'Engineering'",
tableName: "users",
options: &RequestOptions{
Preload: []PreloadOption{
{Relation: "Department"},
},
},
expected: "users.status = 'active' AND Department.name = 'Engineering'",
},
{
name: "Function Call with correct table prefix - unchanged",
where: "ifblnk(users.status,0) in (1,2,3,4)",
tableName: "users",
options: nil,
expected: "ifblnk(users.status,0) in (1,2,3,4)",
},
{
name: "no options provided - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: nil,
expected: "users.status = 'active'",
},
{
name: "empty preload list - works as before",
where: "wrong_table.status = 'active'",
tableName: "users",
options: &RequestOptions{Preload: []PreloadOption{}},
expected: "users.status = 'active'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var result string
if tt.options != nil {
result = SanitizeWhereClause(tt.where, tt.tableName, tt.options)
} else {
result = SanitizeWhereClause(tt.where, tt.tableName)
}
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q, options) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}
// Test model for model-aware sanitization tests
type MasterTask struct {
ID int `bun:"id,pk"`
Name string `bun:"name"`
Status string `bun:"status"`
UserID int `bun:"user_id"`
}
func TestSplitByAND(t *testing.T) {
tests := []struct {
name string
input string
expected []string
}{
{
name: "uppercase AND",
input: "status = 'active' AND age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "lowercase and",
input: "status = 'active' and age > 18",
expected: []string{"status = 'active'", "age > 18"},
},
{
name: "mixed case AND",
input: "status = 'active' AND age > 18 and name = 'John'",
expected: []string{"status = 'active'", "age > 18", "name = 'John'"},
},
{
name: "single condition",
input: "status = 'active'",
expected: []string{"status = 'active'"},
},
{
name: "multiple uppercase AND",
input: "a = 1 AND b = 2 AND c = 3",
expected: []string{"a = 1", "b = 2", "c = 3"},
},
{
name: "multiple case subquery",
input: "a = 1 AND b = 2 AND c = 3 and (select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3",
expected: []string{"a = 1", "b = 2", "c = 3", "(select s from generate_series(1,10) s where s < 10 and s > 0 offset 2 limit 1) = 3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := splitByAND(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("splitByAND(%q) returned %d conditions; want %d", tt.input, len(result), len(tt.expected))
return
}
for i := range result {
if strings.TrimSpace(result[i]) != strings.TrimSpace(tt.expected[i]) {
t.Errorf("splitByAND(%q)[%d] = %q; want %q", tt.input, i, result[i], tt.expected[i])
}
}
})
}
}
func TestValidateWhereClauseSecurity(t *testing.T) {
tests := []struct {
name string
input string
expectError bool
}{
{
name: "safe WHERE clause",
input: "status = 'active' AND age > 18",
expectError: false,
},
{
name: "safe subquery",
input: "id IN (SELECT id FROM users WHERE status = 'active' ORDER BY created_at DESC LIMIT 10)",
expectError: false,
},
{
name: "DELETE keyword",
input: "status = 'active'; DELETE FROM users",
expectError: true,
},
{
name: "UPDATE keyword",
input: "1=1; UPDATE users SET admin = true",
expectError: true,
},
{
name: "TRUNCATE keyword",
input: "status = 'active' OR TRUNCATE TABLE users",
expectError: true,
},
{
name: "DROP keyword",
input: "status = 'active'; DROP TABLE users",
expectError: true,
},
{
name: "INSERT keyword",
input: "status = 'active'; INSERT INTO users (name) VALUES ('hacker')",
expectError: true,
},
{
name: "ALTER keyword",
input: "1=1; ALTER TABLE users ADD COLUMN is_admin BOOLEAN",
expectError: true,
},
{
name: "CREATE keyword",
input: "1=1; CREATE TABLE malicious (id INT)",
expectError: true,
},
{
name: "empty clause",
input: "",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateWhereClauseSecurity(tt.input)
if tt.expectError && err == nil {
t.Errorf("validateWhereClauseSecurity(%q) expected error but got none", tt.input)
}
if !tt.expectError && err != nil {
t.Errorf("validateWhereClauseSecurity(%q) unexpected error: %v", tt.input, err)
}
})
}
}
func TestSanitizeWhereClauseWithModel(t *testing.T) {
// Register the test model
err := modelregistry.RegisterModel(MasterTask{}, "mastertask")
if err != nil {
// Model might already be registered, ignore error
t.Logf("Model registration returned: %v", err)
}
tests := []struct {
name string
where string
tableName string
expected string
}{
{
name: "valid column without prefix - no prefix added",
where: "status = 'active'",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "multiple valid columns without prefix - no prefix added",
where: "status = 'active' AND user_id = 123",
tableName: "mastertask",
expected: "status = 'active' AND user_id = 123",
},
{
name: "incorrect table prefix on valid column - fixed",
where: "wrong_table.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "incorrect prefix on invalid column - not fixed",
where: "wrong_table.invalid_column = 'value'",
tableName: "mastertask",
expected: "wrong_table.invalid_column = 'value'",
},
{
name: "mix of valid and trivial conditions",
where: "true AND status = 'active' AND 1=1",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "parentheses with valid column - no prefix added",
where: "(status = 'active')",
tableName: "mastertask",
expected: "status = 'active'",
},
{
name: "correct prefix - unchanged",
where: "mastertask.status = 'active'",
tableName: "mastertask",
expected: "mastertask.status = 'active'",
},
{
name: "multiple conditions with mixed prefixes",
where: "mastertask.status = 'active' AND wrong_table.user_id = 123",
tableName: "mastertask",
expected: "mastertask.status = 'active' AND mastertask.user_id = 123",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeWhereClause(tt.where, tt.tableName)
if result != tt.expected {
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
}
})
}
}

579
pkg/common/sql_types.go Normal file
View File

@ -0,0 +1,579 @@
// Package common provides nullable SQL types with automatic casting and conversion methods.
package common
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"github.com/google/uuid"
)
// tryParseDT attempts to parse a string into a time.Time using various formats.
func tryParseDT(str string) (time.Time, error) {
var lasterror error
tryFormats := []string{
time.RFC3339,
"2006-01-02T15:04:05.000-0700",
"2006-01-02T15:04:05.000",
"06-01-02T15:04:05.000",
"2006-01-02T15:04:05",
"2006-01-02 15:04:05",
"02/01/2006",
"02-01-2006",
"2006-01-02",
"15:04:05.000",
"15:04:05",
"15:04",
}
for _, f := range tryFormats {
tx, err := time.Parse(f, str)
if err == nil {
return tx, nil
}
lasterror = err
}
return time.Time{}, lasterror // Return zero time on failure
}
// ToJSONDT formats a time.Time to RFC3339 string.
func ToJSONDT(dt time.Time) string {
return dt.Format(time.RFC3339)
}
// SqlNull is a generic nullable type that behaves like sql.NullXXX with auto-casting.
type SqlNull[T any] struct {
Val T
Valid bool
}
// Scan implements sql.Scanner.
func (n *SqlNull[T]) Scan(value any) error {
if value == nil {
n.Valid = false
n.Val = *new(T)
return nil
}
// Try standard sql.Null[T] first.
var sqlNull sql.Null[T]
if err := sqlNull.Scan(value); err == nil {
n.Val = sqlNull.V
n.Valid = sqlNull.Valid
return nil
}
// Fallback: parse from string/bytes.
switch v := value.(type) {
case string:
return n.FromString(v)
case []byte:
return n.FromString(string(v))
default:
return n.FromString(fmt.Sprintf("%v", value))
}
}
func (n *SqlNull[T]) FromString(s string) error {
s = strings.TrimSpace(s)
n.Valid = false
n.Val = *new(T)
if s == "" || strings.EqualFold(s, "null") {
return nil
}
var zero T
switch any(zero).(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
reflect.ValueOf(&n.Val).Elem().SetInt(i)
n.Valid = true
}
case float32, float64:
if f, err := strconv.ParseFloat(s, 64); err == nil {
reflect.ValueOf(&n.Val).Elem().SetFloat(f)
n.Valid = true
}
case bool:
if b, err := strconv.ParseBool(s); err == nil {
n.Val = any(b).(T)
n.Valid = true
}
case time.Time:
if t, err := tryParseDT(s); err == nil && !t.IsZero() {
n.Val = any(t).(T)
n.Valid = true
}
case uuid.UUID:
if u, err := uuid.Parse(s); err == nil {
n.Val = any(u).(T)
n.Valid = true
}
case string:
n.Val = any(s).(T)
n.Valid = true
}
return nil
}
// Value implements driver.Valuer.
func (n SqlNull[T]) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return any(n.Val), nil
}
// MarshalJSON implements json.Marshaler.
func (n SqlNull[T]) MarshalJSON() ([]byte, error) {
if !n.Valid {
return []byte("null"), nil
}
return json.Marshal(n.Val)
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *SqlNull[T]) UnmarshalJSON(b []byte) error {
if len(b) == 0 || string(b) == "null" || strings.TrimSpace(string(b)) == "" {
n.Valid = false
n.Val = *new(T)
return nil
}
// Try direct unmarshal.
var val T
if err := json.Unmarshal(b, &val); err == nil {
n.Val = val
n.Valid = true
return nil
}
// Fallback: unmarshal as string and parse.
var s string
if err := json.Unmarshal(b, &s); err == nil {
return n.FromString(s)
}
return fmt.Errorf("cannot unmarshal %s into SqlNull[%T]", b, n.Val)
}
// String implements fmt.Stringer.
func (n SqlNull[T]) String() string {
if !n.Valid {
return ""
}
return fmt.Sprintf("%v", n.Val)
}
// Int64 converts to int64 or 0 if invalid.
func (n SqlNull[T]) Int64() int64 {
if !n.Valid {
return 0
}
v := reflect.ValueOf(any(n.Val))
switch v.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return v.Int()
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return int64(v.Uint())
case reflect.Float32, reflect.Float64:
return int64(v.Float())
case reflect.String:
i, _ := strconv.ParseInt(v.String(), 10, 64)
return i
case reflect.Bool:
if v.Bool() {
return 1
}
return 0
}
return 0
}
// Float64 converts to float64 or 0.0 if invalid.
func (n SqlNull[T]) Float64() float64 {
if !n.Valid {
return 0.0
}
v := reflect.ValueOf(any(n.Val))
switch v.Kind() {
case reflect.Float32, reflect.Float64:
return v.Float()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return float64(v.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return float64(v.Uint())
case reflect.String:
f, _ := strconv.ParseFloat(v.String(), 64)
return f
}
return 0.0
}
// Bool converts to bool or false if invalid.
func (n SqlNull[T]) Bool() bool {
if !n.Valid {
return false
}
v := reflect.ValueOf(any(n.Val))
if v.Kind() == reflect.Bool {
return v.Bool()
}
s := strings.ToLower(strings.TrimSpace(fmt.Sprint(n.Val)))
return s == "true" || s == "t" || s == "1" || s == "yes" || s == "on"
}
// Time converts to time.Time or zero if invalid.
func (n SqlNull[T]) Time() time.Time {
if !n.Valid {
return time.Time{}
}
if t, ok := any(n.Val).(time.Time); ok {
return t
}
return time.Time{}
}
// UUID converts to uuid.UUID or Nil if invalid.
func (n SqlNull[T]) UUID() uuid.UUID {
if !n.Valid {
return uuid.Nil
}
if u, ok := any(n.Val).(uuid.UUID); ok {
return u
}
return uuid.Nil
}
// Type aliases for common types.
type (
SqlInt16 = SqlNull[int16]
SqlInt32 = SqlNull[int32]
SqlInt64 = SqlNull[int64]
SqlFloat64 = SqlNull[float64]
SqlBool = SqlNull[bool]
SqlString = SqlNull[string]
SqlUUID = SqlNull[uuid.UUID]
)
// SqlTimeStamp - Timestamp with custom formatting (YYYY-MM-DDTHH:MM:SS).
type SqlTimeStamp struct{ SqlNull[time.Time] }
func (t SqlTimeStamp) MarshalJSON() ([]byte, error) {
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, t.Val.Format("2006-01-02T15:04:05"))), nil
}
func (t *SqlTimeStamp) UnmarshalJSON(b []byte) error {
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if t.Valid && (t.Val.IsZero() || t.Val.Format("2006-01-02T15:04:05") == "0001-01-01T00:00:00") {
t.Valid = false
}
return nil
}
func (t SqlTimeStamp) Value() (driver.Value, error) {
if !t.Valid || t.Val.IsZero() || t.Val.Before(time.Date(0002, 1, 1, 0, 0, 0, 0, time.UTC)) {
return nil, nil
}
return t.Val.Format("2006-01-02T15:04:05"), nil
}
func SqlTimeStampNow() SqlTimeStamp {
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlDate - Date only (YYYY-MM-DD).
type SqlDate struct{ SqlNull[time.Time] }
func (d SqlDate) MarshalJSON() ([]byte, error) {
if !d.Valid || d.Val.IsZero() {
return []byte("null"), nil
}
s := d.Val.Format("2006-01-02")
if strings.HasPrefix(s, "0001-01-01") {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, s)), nil
}
func (d *SqlDate) UnmarshalJSON(b []byte) error {
if err := d.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if d.Valid && d.Val.Format("2006-01-02") <= "0001-01-01" {
d.Valid = false
}
return nil
}
func (d SqlDate) Value() (driver.Value, error) {
if !d.Valid || d.Val.IsZero() {
return nil, nil
}
s := d.Val.Format("2006-01-02")
if s <= "0001-01-01" {
return nil, nil
}
return s, nil
}
func (d SqlDate) String() string {
if !d.Valid {
return ""
}
s := d.Val.Format("2006-01-02")
if strings.HasPrefix(s, "0001-01-01") || strings.HasPrefix(s, "1800-12-31") {
return ""
}
return s
}
func SqlDateNow() SqlDate {
return SqlDate{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlTime - Time only (HH:MM:SS).
type SqlTime struct{ SqlNull[time.Time] }
func (t SqlTime) MarshalJSON() ([]byte, error) {
if !t.Valid || t.Val.IsZero() {
return []byte("null"), nil
}
s := t.Val.Format("15:04:05")
if s == "00:00:00" {
return []byte("null"), nil
}
return []byte(fmt.Sprintf(`"%s"`, s)), nil
}
func (t *SqlTime) UnmarshalJSON(b []byte) error {
if err := t.SqlNull.UnmarshalJSON(b); err != nil {
return err
}
if t.Valid && t.Val.Format("15:04:05") == "00:00:00" {
t.Valid = false
}
return nil
}
func (t SqlTime) Value() (driver.Value, error) {
if !t.Valid || t.Val.IsZero() {
return nil, nil
}
return t.Val.Format("15:04:05"), nil
}
func (t SqlTime) String() string {
if !t.Valid {
return ""
}
return t.Val.Format("15:04:05")
}
func SqlTimeNow() SqlTime {
return SqlTime{SqlNull: SqlNull[time.Time]{Val: time.Now(), Valid: true}}
}
// SqlJSONB - Nullable JSONB as []byte.
type SqlJSONB []byte
// Scan implements sql.Scanner.
func (n *SqlJSONB) Scan(value any) error {
if value == nil {
*n = nil
return nil
}
switch v := value.(type) {
case string:
*n = []byte(v)
case []byte:
*n = v
default:
dat, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value to JSON: %v", err)
}
*n = dat
}
return nil
}
// Value implements driver.Valuer.
func (n SqlJSONB) Value() (driver.Value, error) {
if len(n) == 0 {
return nil, nil
}
var js any
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return string(n), nil
}
// MarshalJSON implements json.Marshaler.
func (n SqlJSONB) MarshalJSON() ([]byte, error) {
if len(n) == 0 {
return []byte("null"), nil
}
var obj any
if err := json.Unmarshal(n, &obj); err != nil {
return []byte("null"), nil
}
return n, nil
}
// UnmarshalJSON implements json.Unmarshaler.
func (n *SqlJSONB) UnmarshalJSON(b []byte) error {
s := strings.TrimSpace(string(b))
if s == "null" || s == "" || (!strings.HasPrefix(s, "{") && !strings.HasPrefix(s, "[")) {
*n = nil
return nil
}
*n = b
return nil
}
func (n SqlJSONB) AsMap() (map[string]any, error) {
if len(n) == 0 {
return nil, nil
}
js := make(map[string]any)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
func (n SqlJSONB) AsSlice() ([]any, error) {
if len(n) == 0 {
return nil, nil
}
js := make([]any, 0)
if err := json.Unmarshal(n, &js); err != nil {
return nil, fmt.Errorf("invalid JSON: %v", err)
}
return js, nil
}
// TryIfInt64 tries to parse any value to int64 with default.
func TryIfInt64(v any, def int64) int64 {
switch val := v.(type) {
case string:
i, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return def
}
return i
case int:
return int64(val)
case int8:
return int64(val)
case int16:
return int64(val)
case int32:
return int64(val)
case int64:
return val
case uint:
return int64(val)
case uint8:
return int64(val)
case uint16:
return int64(val)
case uint32:
return int64(val)
case uint64:
return int64(val)
case float32:
return int64(val)
case float64:
return int64(val)
case []byte:
i, err := strconv.ParseInt(string(val), 10, 64)
if err != nil {
return def
}
return i
default:
return def
}
}
// Constructor helpers - clean and fast value creation
func Null[T any](v T, valid bool) SqlNull[T] {
return SqlNull[T]{Val: v, Valid: valid}
}
func NewSql[T any](value any) SqlNull[T] {
n := SqlNull[T]{}
if value == nil {
return n
}
// Fast path: exact match
if v, ok := value.(T); ok {
n.Val = v
n.Valid = true
return n
}
// Try from another SqlNull
if sn, ok := value.(SqlNull[T]); ok {
return sn
}
// Convert via string
_ = n.FromString(fmt.Sprintf("%v", value))
return n
}
func NewSqlInt16(v int16) SqlInt16 {
return SqlInt16{Val: v, Valid: true}
}
func NewSqlInt32(v int32) SqlInt32 {
return SqlInt32{Val: v, Valid: true}
}
func NewSqlInt64(v int64) SqlInt64 {
return SqlInt64{Val: v, Valid: true}
}
func NewSqlFloat64(v float64) SqlFloat64 {
return SqlFloat64{Val: v, Valid: true}
}
func NewSqlBool(v bool) SqlBool {
return SqlBool{Val: v, Valid: true}
}
func NewSqlString(v string) SqlString {
return SqlString{Val: v, Valid: true}
}
func NewSqlUUID(v uuid.UUID) SqlUUID {
return SqlUUID{Val: v, Valid: true}
}
func NewSqlTimeStamp(v time.Time) SqlTimeStamp {
return SqlTimeStamp{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}
func NewSqlDate(v time.Time) SqlDate {
return SqlDate{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}
func NewSqlTime(v time.Time) SqlTime {
return SqlTime{SqlNull: SqlNull[time.Time]{Val: v, Valid: true}}
}

View File

@ -0,0 +1,566 @@
package common
import (
"database/sql/driver"
"encoding/json"
"testing"
"time"
"github.com/google/uuid"
)
// TestNewSqlInt16 tests NewSqlInt16 type
func TestNewSqlInt16(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt16
}{
{"int", 42, Null(int16(42), true)},
{"int32", int32(100), NewSqlInt16(100)},
{"int64", int64(200), NewSqlInt16(200)},
{"string", "123", NewSqlInt16(123)},
{"nil", nil, Null(int16(0), false)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt16
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
func TestNewSqlInt16_Value(t *testing.T) {
tests := []struct {
name string
input SqlInt16
expected driver.Value
}{
{"zero", Null(int16(0), false), nil},
{"positive", NewSqlInt16(42), int16(42)},
{"negative", NewSqlInt16(-10), int16(-10)},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, val)
}
})
}
}
func TestNewSqlInt16_JSON(t *testing.T) {
n := NewSqlInt16(42)
// Marshal
data, err := json.Marshal(n)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := "42"
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var n2 SqlInt16
if err := json.Unmarshal([]byte("123"), &n2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if n2.Int64() != 123 {
t.Errorf("expected 123, got %d", n2.Int64())
}
}
// TestNewSqlInt64 tests NewSqlInt64 type
func TestNewSqlInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected SqlInt64
}{
{"int", 42, NewSqlInt64(42)},
{"int32", int32(100), NewSqlInt64(100)},
{"int64", int64(9223372036854775807), NewSqlInt64(9223372036854775807)},
{"uint32", uint32(100), NewSqlInt64(100)},
{"uint64", uint64(200), NewSqlInt64(200)},
{"nil", nil, SqlInt64{}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlInt64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n)
}
})
}
}
// TestSqlFloat64 tests SqlFloat64 type
func TestSqlFloat64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected float64
valid bool
}{
{"float64", float64(3.14), 3.14, true},
{"float32", float32(2.5), 2.5, true},
{"int", 42, 42.0, true},
{"int64", int64(100), 100.0, true},
{"nil", nil, 0, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var n SqlFloat64
if err := n.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if n.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, n.Valid)
}
if tt.valid && n.Float64() != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, n.Float64())
}
})
}
}
// TestSqlTimeStamp tests SqlTimeStamp type
func TestSqlTimeStamp(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string RFC3339", now.Format(time.RFC3339)},
{"string date", "2024-01-15"},
{"string datetime", "2024-01-15T10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ts SqlTimeStamp
if err := ts.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if ts.Time().IsZero() {
t.Error("expected non-zero time")
}
})
}
}
func TestSqlTimeStamp_JSON(t *testing.T) {
now := time.Date(2024, 1, 15, 10, 30, 45, 0, time.UTC)
ts := NewSqlTimeStamp(now)
// Marshal
data, err := json.Marshal(ts)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15T10:30:45"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var ts2 SqlTimeStamp
if err := json.Unmarshal([]byte(`"2024-01-15T10:30:45"`), &ts2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if ts2.Time().Year() != 2024 {
t.Errorf("expected year 2024, got %d", ts2.Time().Year())
}
// Test null
var ts3 SqlTimeStamp
if err := json.Unmarshal([]byte("null"), &ts3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
// TestSqlDate tests SqlDate type
func TestSqlDate(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
}{
{"time.Time", now},
{"string date", "2024-01-15"},
{"string UK format", "15/01/2024"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var d SqlDate
if err := d.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if d.String() == "0" {
t.Error("expected non-zero date")
}
})
}
}
func TestSqlDate_JSON(t *testing.T) {
date := NewSqlDate(time.Date(2024, 1, 15, 0, 0, 0, 0, time.UTC))
// Marshal
data, err := json.Marshal(date)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"2024-01-15"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var d2 SqlDate
if err := json.Unmarshal([]byte(`"2024-01-15"`), &d2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
}
// TestSqlTime tests SqlTime type
func TestSqlTime(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
expected string
}{
{"time.Time", now, now.Format("15:04:05")},
{"string time", "10:30:45", "10:30:45"},
{"string short time", "10:30", "10:30:00"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tm SqlTime
if err := tm.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tm.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, tm.String())
}
})
}
}
// TestSqlJSONB tests SqlJSONB type
func TestSqlJSONB_Scan(t *testing.T) {
tests := []struct {
name string
input interface{}
expected string
}{
{"string JSON object", `{"key":"value"}`, `{"key":"value"}`},
{"string JSON array", `[1,2,3]`, `[1,2,3]`},
{"bytes", []byte(`{"test":true}`), `{"test":true}`},
{"nil", nil, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var j SqlJSONB
if err := j.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if tt.expected == "" && j == nil {
return // nil case
}
if string(j) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, string(j))
}
})
}
}
func TestSqlJSONB_Value(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
expected string
wantErr bool
}{
{"valid object", SqlJSONB(`{"key":"value"}`), `{"key":"value"}`, false},
{"valid array", SqlJSONB(`[1,2,3]`), `[1,2,3]`, false},
{"empty", SqlJSONB{}, "", false},
{"nil", nil, "", false},
{"invalid JSON", SqlJSONB(`{invalid`), "", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, err := tt.input.Value()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if tt.expected == "" && val == nil {
return // nil case
}
if val.(string) != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, val)
}
})
}
}
func TestSqlJSONB_JSON(t *testing.T) {
// Marshal
j := SqlJSONB(`{"name":"test","count":42}`)
data, err := json.Marshal(j)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
var result map[string]interface{}
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("Unmarshal result failed: %v", err)
}
if result["name"] != "test" {
t.Errorf("expected name=test, got %v", result["name"])
}
// Unmarshal
var j2 SqlJSONB
if err := json.Unmarshal([]byte(`{"key":"value"}`), &j2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if string(j2) != `{"key":"value"}` {
t.Errorf("expected {\"key\":\"value\"}, got %s", string(j2))
}
// Test null
var j3 SqlJSONB
if err := json.Unmarshal([]byte("null"), &j3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
}
func TestSqlJSONB_AsMap(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid object", SqlJSONB(`{"name":"test","age":30}`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`{invalid`), true, false},
{"array not object", SqlJSONB(`[1,2,3]`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m, err := tt.input.AsMap()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsMap failed: %v", err)
}
if tt.wantNil {
if m != nil {
t.Errorf("expected nil, got %v", m)
}
return
}
if m == nil {
t.Error("expected non-nil map")
}
})
}
}
func TestSqlJSONB_AsSlice(t *testing.T) {
tests := []struct {
name string
input SqlJSONB
wantErr bool
wantNil bool
}{
{"valid array", SqlJSONB(`[1,2,3]`), false, false},
{"empty", SqlJSONB{}, false, true},
{"nil", nil, false, true},
{"invalid JSON", SqlJSONB(`[invalid`), true, false},
{"object not array", SqlJSONB(`{"key":"value"}`), true, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s, err := tt.input.AsSlice()
if tt.wantErr {
if err == nil {
t.Error("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("AsSlice failed: %v", err)
}
if tt.wantNil {
if s != nil {
t.Errorf("expected nil, got %v", s)
}
return
}
if s == nil {
t.Error("expected non-nil slice")
}
})
}
}
// TestSqlUUID tests SqlUUID type
func TestSqlUUID_Scan(t *testing.T) {
testUUID := uuid.New()
testUUIDStr := testUUID.String()
tests := []struct {
name string
input interface{}
expected string
valid bool
}{
{"string UUID", testUUIDStr, testUUIDStr, true},
{"bytes UUID", []byte(testUUIDStr), testUUIDStr, true},
{"nil", nil, "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var u SqlUUID
if err := u.Scan(tt.input); err != nil {
t.Fatalf("Scan failed: %v", err)
}
if u.Valid != tt.valid {
t.Errorf("expected valid=%v, got valid=%v", tt.valid, u.Valid)
}
if tt.valid && u.String() != tt.expected {
t.Errorf("expected %s, got %s", tt.expected, u.String())
}
})
}
}
func TestSqlUUID_Value(t *testing.T) {
testUUID := uuid.New()
u := NewSqlUUID(testUUID)
val, err := u.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val != testUUID {
t.Errorf("expected %s, got %s", testUUID.String(), val)
}
// Test invalid UUID
u2 := SqlUUID{Valid: false}
val2, err := u2.Value()
if err != nil {
t.Fatalf("Value failed: %v", err)
}
if val2 != nil {
t.Errorf("expected nil, got %v", val2)
}
}
func TestSqlUUID_JSON(t *testing.T) {
testUUID := uuid.New()
u := NewSqlUUID(testUUID)
// Marshal
data, err := json.Marshal(u)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `"` + testUUID.String() + `"`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
// Unmarshal
var u2 SqlUUID
if err := json.Unmarshal([]byte(`"`+testUUID.String()+`"`), &u2); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
if u2.String() != testUUID.String() {
t.Errorf("expected %s, got %s", testUUID.String(), u2.String())
}
// Test null
var u3 SqlUUID
if err := json.Unmarshal([]byte("null"), &u3); err != nil {
t.Fatalf("Unmarshal null failed: %v", err)
}
if u3.Valid {
t.Error("expected invalid UUID")
}
}
// TestTryIfInt64 tests the TryIfInt64 helper function
func TestTryIfInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
def int64
expected int64
}{
{"string valid", "123", 0, 123},
{"string invalid", "abc", 99, 99},
{"int", 42, 0, 42},
{"int32", int32(100), 0, 100},
{"int64", int64(200), 0, 200},
{"uint32", uint32(50), 0, 50},
{"uint64", uint64(75), 0, 75},
{"float32", float32(3.14), 0, 3},
{"float64", float64(2.71), 0, 2},
{"bytes", []byte("456"), 0, 456},
{"unknown type", struct{}{}, 999, 999},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := TryIfInt64(tt.input, tt.def)
if result != tt.expected {
t.Errorf("expected %d, got %d", tt.expected, result)
}
})
}
}

View File

@ -18,6 +18,11 @@ type RequestOptions struct {
CustomOperators []CustomOperator `json:"customOperators"` CustomOperators []CustomOperator `json:"customOperators"`
ComputedColumns []ComputedColumn `json:"computedColumns"` ComputedColumns []ComputedColumn `json:"computedColumns"`
Parameters []Parameter `json:"parameters"` Parameters []Parameter `json:"parameters"`
// Cursor pagination
CursorForward string `json:"cursor_forward"`
CursorBackward string `json:"cursor_backward"`
FetchRowNumber *string `json:"fetch_row_number"`
} }
type Parameter struct { type Parameter struct {
@ -30,16 +35,26 @@ type PreloadOption struct {
Relation string `json:"relation"` Relation string `json:"relation"`
Columns []string `json:"columns"` Columns []string `json:"columns"`
OmitColumns []string `json:"omit_columns"` OmitColumns []string `json:"omit_columns"`
Sort []SortOption `json:"sort"`
Filters []FilterOption `json:"filters"` Filters []FilterOption `json:"filters"`
Where string `json:"where"`
Limit *int `json:"limit"` Limit *int `json:"limit"`
Offset *int `json:"offset"` Offset *int `json:"offset"`
Updatable *bool `json:"updateable"` // if true, the relation can be updated Updatable *bool `json:"updateable"` // if true, the relation can be updated
ComputedQL map[string]string `json:"computed_ql"` // Computed columns as SQL expressions
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
// Relationship keys from XFiles - used to build proper foreign key filters
PrimaryKey string `json:"primary_key"` // Primary key of the related table
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
} }
type FilterOption struct { type FilterOption struct {
Column string `json:"column"` Column string `json:"column"`
Operator string `json:"operator"` Operator string `json:"operator"`
Value interface{} `json:"value"` Value interface{} `json:"value"`
LogicOperator string `json:"logic_operator"` // "AND" or "OR" - how this filter combines with previous filters
} }
type SortOption struct { type SortOption struct {
@ -67,9 +82,11 @@ type Response struct {
type Metadata struct { type Metadata struct {
Total int64 `json:"total"` Total int64 `json:"total"`
Count int64 `json:"count"`
Filtered int64 `json:"filtered"` Filtered int64 `json:"filtered"`
Limit int `json:"limit"` Limit int `json:"limit"`
Offset int `json:"offset"` Offset int `json:"offset"`
RowNumber *int64 `json:"row_number,omitempty"`
} }
type APIError struct { type APIError struct {

287
pkg/common/validation.go Normal file
View File

@ -0,0 +1,287 @@
package common
import (
"fmt"
"reflect"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// ColumnValidator validates column names against a model's fields
type ColumnValidator struct {
validColumns map[string]bool
model interface{}
}
// NewColumnValidator creates a new column validator for a given model
func NewColumnValidator(model interface{}) *ColumnValidator {
validator := &ColumnValidator{
validColumns: make(map[string]bool),
model: model,
}
validator.buildValidColumns()
return validator
}
// buildValidColumns extracts all valid column names from the model using reflection
func (v *ColumnValidator) buildValidColumns() {
modelType := reflect.TypeOf(v.model)
// Unwrap pointers, slices, and arrays to get to the base struct type
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
modelType = modelType.Elem()
}
// Validate that we have a struct type
if modelType == nil || modelType.Kind() != reflect.Struct {
return
}
// Extract column names from struct fields
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
if !field.IsExported() {
continue
}
// Get column name from bun, gorm, or json tag
columnName := v.getColumnName(field)
if columnName != "" && columnName != "-" {
v.validColumns[strings.ToLower(columnName)] = true
}
}
}
// getColumnName extracts the column name from a struct field's tags
// Supports both Bun and GORM tags
func (v *ColumnValidator) getColumnName(field reflect.StructField) string {
// First check Bun tag for column name
bunTag := field.Tag.Get("bun")
if bunTag != "" && bunTag != "-" {
parts := strings.Split(bunTag, ",")
// The first part is usually the column name
columnName := strings.TrimSpace(parts[0])
if columnName != "" && columnName != "-" {
return columnName
}
}
// Check GORM tag for column name
gormTag := field.Tag.Get("gorm")
if strings.Contains(gormTag, "column:") {
parts := strings.Split(gormTag, ";")
for _, part := range parts {
part = strings.TrimSpace(part)
if strings.HasPrefix(part, "column:") {
return strings.TrimPrefix(part, "column:")
}
}
}
// Fall back to JSON tag
jsonTag := field.Tag.Get("json")
if jsonTag != "" && jsonTag != "-" {
// Extract just the name part (before any comma)
jsonName := strings.Split(jsonTag, ",")[0]
return jsonName
}
// Fall back to field name in lowercase (snake_case conversion would be better)
return strings.ToLower(field.Name)
}
// ValidateColumn validates a single column name
// Returns nil if valid, error if invalid
// Columns prefixed with "cql" (case insensitive) are always valid
// Handles PostgreSQL JSON operators (-> and ->>)
func (v *ColumnValidator) ValidateColumn(column string) error {
// Allow empty columns
if column == "" {
return nil
}
// Allow columns prefixed with "cql" (case insensitive) for computed columns
if strings.HasPrefix(strings.ToLower(column), "cql") {
return nil
}
// Extract source column name (remove JSON operators like ->> or ->)
sourceColumn := reflection.ExtractSourceColumn(column)
// Check if column exists in model
if _, exists := v.validColumns[strings.ToLower(sourceColumn)]; !exists {
return fmt.Errorf("invalid column '%s': column does not exist in model", column)
}
return nil
}
// IsValidColumn checks if a column is valid
// Returns true if valid, false if invalid
func (v *ColumnValidator) IsValidColumn(column string) bool {
return v.ValidateColumn(column) == nil
}
// FilterValidColumns filters a list of columns, returning only valid ones
// Logs warnings for any invalid columns
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
if len(columns) == 0 {
return columns
}
validColumns := make([]string, 0, len(columns))
for _, col := range columns {
if v.IsValidColumn(col) {
validColumns = append(validColumns, col)
} else {
logger.Warn("Invalid column '%s' filtered out: column does not exist in model", col)
}
}
return validColumns
}
// ValidateColumns validates multiple column names
// Returns error with details about all invalid columns
func (v *ColumnValidator) ValidateColumns(columns []string) error {
var invalidColumns []string
for _, column := range columns {
if err := v.ValidateColumn(column); err != nil {
invalidColumns = append(invalidColumns, column)
}
}
if len(invalidColumns) > 0 {
return fmt.Errorf("invalid columns: %s", strings.Join(invalidColumns, ", "))
}
return nil
}
// ValidateRequestOptions validates all column references in RequestOptions
func (v *ColumnValidator) ValidateRequestOptions(options RequestOptions) error {
// Validate Columns
if err := v.ValidateColumns(options.Columns); err != nil {
return fmt.Errorf("in select columns: %w", err)
}
// Validate OmitColumns
if err := v.ValidateColumns(options.OmitColumns); err != nil {
return fmt.Errorf("in omit columns: %w", err)
}
// Validate Filter columns
for _, filter := range options.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in filter: %w", err)
}
}
// Validate Sort columns
for _, sort := range options.Sort {
if err := v.ValidateColumn(sort.Column); err != nil {
return fmt.Errorf("in sort: %w", err)
}
}
// Validate Preload columns (if specified)
for idx := range options.Preload {
preload := options.Preload[idx]
// Note: We don't validate the relation name itself, as it's a relationship
// Only validate columns if specified for the preload
if err := v.ValidateColumns(preload.Columns); err != nil {
return fmt.Errorf("in preload '%s' columns: %w", preload.Relation, err)
}
if err := v.ValidateColumns(preload.OmitColumns); err != nil {
return fmt.Errorf("in preload '%s' omit columns: %w", preload.Relation, err)
}
// Validate filter columns in preload
for _, filter := range preload.Filters {
if err := v.ValidateColumn(filter.Column); err != nil {
return fmt.Errorf("in preload '%s' filter: %w", preload.Relation, err)
}
}
}
return nil
}
// FilterRequestOptions filters all column references in RequestOptions
// Returns a new RequestOptions with only valid columns, logging warnings for invalid ones
func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOptions {
filtered := options
// Filter Columns
filtered.Columns = v.FilterValidColumns(options.Columns)
// Filter OmitColumns
filtered.OmitColumns = v.FilterValidColumns(options.OmitColumns)
// Filter Filter columns
validFilters := make([]FilterOption, 0, len(options.Filters))
for _, filter := range options.Filters {
if v.IsValidColumn(filter.Column) {
validFilters = append(validFilters, filter)
} else {
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
}
}
filtered.Filters = validFilters
// Filter Sort columns
validSorts := make([]SortOption, 0, len(options.Sort))
for _, sort := range options.Sort {
if v.IsValidColumn(sort.Column) {
validSorts = append(validSorts, sort)
} else {
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
}
}
filtered.Sort = validSorts
// Filter Preload columns
validPreloads := make([]PreloadOption, 0, len(options.Preload))
for idx := range options.Preload {
preload := options.Preload[idx]
filteredPreload := preload
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
// Filter preload filters
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
for _, filter := range preload.Filters {
if v.IsValidColumn(filter.Column) {
validPreloadFilters = append(validPreloadFilters, filter)
} else {
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
}
}
filteredPreload.Filters = validPreloadFilters
validPreloads = append(validPreloads, filteredPreload)
}
filtered.Preload = validPreloads
return filtered
}
// GetValidColumns returns a list of all valid column names for debugging purposes
func (v *ColumnValidator) GetValidColumns() []string {
columns := make([]string, 0, len(v.validColumns))
for col := range v.validColumns {
columns = append(columns, col)
}
return columns
}
func QuoteIdent(qualifier string) string {
return `"` + strings.ReplaceAll(qualifier, `"`, `""`) + `"`
}
func QuoteLiteral(value string) string {
return `'` + strings.ReplaceAll(value, `'`, `''`) + `'`
}

View File

@ -0,0 +1,126 @@
package common
import (
"testing"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
func TestExtractSourceColumn(t *testing.T) {
testCases := []struct {
name string
input string
expected string
}{
{
name: "simple column name",
input: "columna",
expected: "columna",
},
{
name: "column with ->> operator",
input: "columna->>'val'",
expected: "columna",
},
{
name: "column with -> operator",
input: "columna->'key'",
expected: "columna",
},
{
name: "column with table prefix and ->> operator",
input: "table.columna->>'val'",
expected: "table.columna",
},
{
name: "column with table prefix and -> operator",
input: "table.columna->'key'",
expected: "table.columna",
},
{
name: "complex JSON path with ->>",
input: "data->>'nested'->>'value'",
expected: "data",
},
{
name: "column with spaces before operator",
input: "columna ->>'val'",
expected: "columna",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := reflection.ExtractSourceColumn(tc.input)
if result != tc.expected {
t.Errorf("reflection.ExtractSourceColumn(%q) = %q; want %q", tc.input, result, tc.expected)
}
})
}
}
func TestValidateColumnWithJSONOperators(t *testing.T) {
// Create a test model
type TestModel struct {
ID int `json:"id"`
Name string `json:"name"`
Data string `json:"data"` // JSON column
Metadata string `json:"metadata"`
}
validator := NewColumnValidator(TestModel{})
testCases := []struct {
name string
column string
shouldErr bool
}{
{
name: "simple valid column",
column: "name",
shouldErr: false,
},
{
name: "valid column with ->> operator",
column: "data->>'field'",
shouldErr: false,
},
{
name: "valid column with -> operator",
column: "metadata->'key'",
shouldErr: false,
},
{
name: "invalid column",
column: "invalid_column",
shouldErr: true,
},
{
name: "invalid column with ->> operator",
column: "invalid_column->>'field'",
shouldErr: true,
},
{
name: "cql prefixed column (always valid)",
column: "cql_computed",
shouldErr: false,
},
{
name: "empty column",
column: "",
shouldErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := validator.ValidateColumn(tc.column)
if tc.shouldErr && err == nil {
t.Errorf("ValidateColumn(%q) expected error, got nil", tc.column)
}
if !tc.shouldErr && err != nil {
t.Errorf("ValidateColumn(%q) expected no error, got %v", tc.column, err)
}
})
}
}

View File

@ -0,0 +1,363 @@
package common
import (
"strings"
"testing"
)
// TestModel represents a sample model for testing
type TestModel struct {
ID int64 `json:"id" gorm:"primaryKey"`
Name string `json:"name" gorm:"column:name"`
Email string `json:"email" bun:"email"`
Age int `json:"age"`
IsActive bool `json:"is_active"`
CreatedAt string `json:"created_at"`
}
func TestNewColumnValidator(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
if validator == nil {
t.Fatal("Expected validator to be created")
}
if len(validator.validColumns) == 0 {
t.Fatal("Expected validator to have valid columns")
}
// Check that expected columns are present
expectedColumns := []string{"id", "name", "email", "age", "is_active", "created_at"}
for _, col := range expectedColumns {
if !validator.validColumns[col] {
t.Errorf("Expected column '%s' to be valid", col)
}
}
}
func TestValidateColumn(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
column string
shouldError bool
}{
{"Valid column - id", "id", false},
{"Valid column - name", "name", false},
{"Valid column - email", "email", false},
{"Valid column - uppercase", "ID", false}, // Case insensitive
{"Invalid column", "invalid_column", true},
{"CQL prefixed - should be valid", "cqlComputedField", false},
{"CQL prefixed uppercase - should be valid", "CQLComputedField", false},
{"Empty column", "", false}, // Empty columns are allowed
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s', got nil", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestValidateColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
columns []string
shouldError bool
}{
{"All valid columns", []string{"id", "name", "email"}, false},
{"One invalid column", []string{"id", "invalid_col", "name"}, true},
{"All invalid columns", []string{"bad1", "bad2"}, true},
{"With CQL prefix", []string{"id", "cqlComputed", "name"}, false},
{"Empty list", []string{}, false},
{"Nil list", nil, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateColumns(tt.columns)
if tt.shouldError && err == nil {
t.Errorf("Expected error for columns %v, got nil", tt.columns)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for columns %v, got: %v", tt.columns, err)
}
})
}
}
func TestValidateRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
options RequestOptions
shouldError bool
errorMsg string
}{
{
name: "Valid options with columns",
options: RequestOptions{
Columns: []string{"id", "name"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
},
},
shouldError: false,
},
{
name: "Invalid column in Columns",
options: RequestOptions{
Columns: []string{"id", "invalid_column"},
},
shouldError: true,
errorMsg: "select columns",
},
{
name: "Invalid column in Filters",
options: RequestOptions{
Filters: []FilterOption{
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
},
shouldError: true,
errorMsg: "filter",
},
{
name: "Invalid column in Sort",
options: RequestOptions{
Sort: []SortOption{
{Column: "invalid_col", Direction: "ASC"},
},
},
shouldError: true,
errorMsg: "sort",
},
{
name: "Valid CQL prefixed columns",
options: RequestOptions{
Columns: []string{"id", "cqlComputedField"},
Filters: []FilterOption{
{Column: "cqlCustomFilter", Operator: "eq", Value: "test"},
},
},
shouldError: false,
},
{
name: "Invalid column in Preload",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "invalid_col"},
},
},
},
shouldError: true,
errorMsg: "preload",
},
{
name: "Valid preload with valid columns",
options: RequestOptions{
Preload: []PreloadOption{
{
Relation: "SomeRelation",
Columns: []string{"id", "name"},
},
},
},
shouldError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validator.ValidateRequestOptions(tt.options)
if tt.shouldError {
if err == nil {
t.Errorf("Expected error, got nil")
} else if tt.errorMsg != "" && !strings.Contains(err.Error(), tt.errorMsg) {
t.Errorf("Expected error to contain '%s', got: %v", tt.errorMsg, err)
}
} else {
if err != nil {
t.Errorf("Expected no error, got: %v", err)
}
}
})
}
}
func TestGetValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
columns := validator.GetValidColumns()
if len(columns) == 0 {
t.Error("Expected to get valid columns, got empty list")
}
// Should have at least the columns from TestModel
if len(columns) < 6 {
t.Errorf("Expected at least 6 columns, got %d", len(columns))
}
}
// Test with Bun tags specifically
type BunModel struct {
ID int64 `bun:"id,pk"`
Name string `bun:"name"`
Email string `bun:"user_email"`
}
func TestBunTagSupport(t *testing.T) {
model := BunModel{}
validator := NewColumnValidator(model)
// Test that bun tags are properly recognized
tests := []struct {
column string
shouldError bool
}{
{"id", false},
{"name", false},
{"user_email", false}, // Bun tag specifies this name
{"email", true}, // JSON tag would be "email", but bun tag says "user_email"
}
for _, tt := range tests {
t.Run(tt.column, func(t *testing.T) {
err := validator.ValidateColumn(tt.column)
if tt.shouldError && err == nil {
t.Errorf("Expected error for column '%s'", tt.column)
}
if !tt.shouldError && err != nil {
t.Errorf("Expected no error for column '%s', got: %v", tt.column, err)
}
})
}
}
func TestFilterValidColumns(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
tests := []struct {
name string
input []string
expectedOutput []string
}{
{
name: "All valid columns",
input: []string{"id", "name", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "Mix of valid and invalid",
input: []string{"id", "invalid_col", "name", "bad_col", "email"},
expectedOutput: []string{"id", "name", "email"},
},
{
name: "All invalid columns",
input: []string{"bad1", "bad2"},
expectedOutput: []string{},
},
{
name: "With CQL prefix (should pass)",
input: []string{"id", "cqlComputed", "name"},
expectedOutput: []string{"id", "cqlComputed", "name"},
},
{
name: "Empty input",
input: []string{},
expectedOutput: []string{},
},
{
name: "Nil input",
input: nil,
expectedOutput: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := validator.FilterValidColumns(tt.input)
if len(result) != len(tt.expectedOutput) {
t.Errorf("Expected %d columns, got %d", len(tt.expectedOutput), len(result))
}
for i, col := range result {
if col != tt.expectedOutput[i] {
t.Errorf("At index %d: expected %s, got %s", i, tt.expectedOutput[i], col)
}
}
})
}
}
func TestFilterRequestOptions(t *testing.T) {
model := TestModel{}
validator := NewColumnValidator(model)
options := RequestOptions{
Columns: []string{"id", "name", "invalid_col"},
OmitColumns: []string{"email", "bad_col"},
Filters: []FilterOption{
{Column: "name", Operator: "eq", Value: "test"},
{Column: "invalid_col", Operator: "eq", Value: "test"},
},
Sort: []SortOption{
{Column: "id", Direction: "ASC"},
{Column: "bad_col", Direction: "DESC"},
},
}
filtered := validator.FilterRequestOptions(options)
// Check Columns
if len(filtered.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
}
if filtered.Columns[0] != "id" || filtered.Columns[1] != "name" {
t.Errorf("Expected columns [id, name], got %v", filtered.Columns)
}
// Check OmitColumns
if len(filtered.OmitColumns) != 1 {
t.Errorf("Expected 1 omit column, got %d", len(filtered.OmitColumns))
}
if filtered.OmitColumns[0] != "email" {
t.Errorf("Expected omit column [email], got %v", filtered.OmitColumns)
}
// Check Filters
if len(filtered.Filters) != 1 {
t.Errorf("Expected 1 filter, got %d", len(filtered.Filters))
}
if filtered.Filters[0].Column != "name" {
t.Errorf("Expected filter column 'name', got %s", filtered.Filters[0].Column)
}
// Check Sort
if len(filtered.Sort) != 1 {
t.Errorf("Expected 1 sort, got %d", len(filtered.Sort))
}
if filtered.Sort[0].Column != "id" {
t.Errorf("Expected sort column 'id', got %s", filtered.Sort[0].Column)
}
}

291
pkg/config/README.md Normal file
View File

@ -0,0 +1,291 @@
# ResolveSpec Configuration System
A centralized configuration system with support for multiple configuration sources: config files (YAML, TOML, JSON), environment variables, and programmatic configuration.
## Features
- **Multiple Config Sources**: Config files, environment variables, and code
- **Priority Order**: Environment variables > Config file > Defaults
- **Multiple Formats**: YAML, TOML, JSON supported
- **Type Safety**: Strongly-typed configuration structs
- **Sensible Defaults**: Works out of the box with reasonable defaults
## Quick Start
### Basic Usage
```go
import "github.com/heinhel/ResolveSpec/pkg/config"
// Create a new config manager
mgr := config.NewManager()
// Load configuration from file and environment
if err := mgr.Load(); err != nil {
log.Fatal(err)
}
// Get the complete configuration
cfg, err := mgr.GetConfig()
if err != nil {
log.Fatal(err)
}
// Use the configuration
fmt.Println("Server address:", cfg.Server.Addr)
```
### Custom Configuration Paths
```go
mgr := config.NewManagerWithOptions(
config.WithConfigFile("/path/to/config.yaml"),
config.WithEnvPrefix("MYAPP"),
)
```
## Configuration Sources
### 1. Config Files
Place a `config.yaml` file in one of these locations:
- Current directory (`.`)
- `./config/`
- `/etc/resolvespec/`
- `$HOME/.resolvespec/`
Example `config.yaml`:
```yaml
server:
addr: ":8080"
shutdown_timeout: 30s
tracing:
enabled: true
service_name: "my-service"
cache:
provider: "redis"
redis:
host: "localhost"
port: 6379
```
### 2. Environment Variables
All configuration can be set via environment variables with the `RESOLVESPEC_` prefix:
```bash
export RESOLVESPEC_SERVER_ADDR=":9090"
export RESOLVESPEC_TRACING_ENABLED=true
export RESOLVESPEC_CACHE_PROVIDER=redis
export RESOLVESPEC_CACHE_REDIS_HOST=localhost
```
Nested configuration uses underscores:
- `server.addr``RESOLVESPEC_SERVER_ADDR`
- `cache.redis.host``RESOLVESPEC_CACHE_REDIS_HOST`
### 3. Programmatic Configuration
```go
mgr := config.NewManager()
mgr.Set("server.addr", ":9090")
mgr.Set("tracing.enabled", true)
cfg, _ := mgr.GetConfig()
```
## Configuration Options
### Server Configuration
```yaml
server:
addr: ":8080" # Server address
shutdown_timeout: 30s # Graceful shutdown timeout
drain_timeout: 25s # Connection drain timeout
read_timeout: 10s # HTTP read timeout
write_timeout: 10s # HTTP write timeout
idle_timeout: 120s # HTTP idle timeout
```
### Tracing Configuration
```yaml
tracing:
enabled: false # Enable/disable tracing
service_name: "resolvespec" # Service name
service_version: "1.0.0" # Service version
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
```
### Cache Configuration
```yaml
cache:
provider: "memory" # Options: memory, redis, memcache
redis:
host: "localhost"
port: 6379
password: ""
db: 0
memcache:
servers:
- "localhost:11211"
max_idle_conns: 10
timeout: 100ms
```
### Logger Configuration
```yaml
logger:
dev: false # Development mode (human-readable output)
path: "" # Log file path (empty = stdout)
```
### Middleware Configuration
```yaml
middleware:
rate_limit_rps: 100.0 # Requests per second
rate_limit_burst: 200 # Burst size
max_request_size: 10485760 # Max request size in bytes (10MB)
```
### CORS Configuration
```yaml
cors:
allowed_origins:
- "*"
allowed_methods:
- "GET"
- "POST"
- "PUT"
- "DELETE"
- "OPTIONS"
allowed_headers:
- "*"
max_age: 3600
```
### Database Configuration
```yaml
database:
url: "host=localhost user=postgres password=postgres dbname=mydb port=5432 sslmode=disable"
```
## Priority and Overrides
Configuration sources are applied in this order (highest priority first):
1. **Environment Variables** (highest priority)
2. **Config File**
3. **Defaults** (lowest priority)
This allows you to:
- Set defaults in code
- Override with a config file
- Override specific values with environment variables
## Examples
### Production Setup
```yaml
# config.yaml
server:
addr: ":8080"
tracing:
enabled: true
service_name: "myapi"
endpoint: "http://jaeger:4318/v1/traces"
cache:
provider: "redis"
redis:
host: "redis"
port: 6379
password: "${REDIS_PASSWORD}"
logger:
dev: false
path: "/var/log/myapi/app.log"
```
### Development Setup
```bash
# Use environment variables for development
export RESOLVESPEC_LOGGER_DEV=true
export RESOLVESPEC_TRACING_ENABLED=false
export RESOLVESPEC_CACHE_PROVIDER=memory
```
### Testing Setup
```go
// Override config for tests
mgr := config.NewManager()
mgr.Set("cache.provider", "memory")
mgr.Set("database.url", testDBURL)
cfg, _ := mgr.GetConfig()
```
## Best Practices
1. **Use config files for base configuration** - Define your standard settings
2. **Use environment variables for secrets** - Never commit passwords/tokens
3. **Use environment variables for deployment-specific values** - Different per environment
4. **Keep defaults sensible** - Application should work with minimal configuration
5. **Document your configuration** - Comment your config.yaml files
## Integration with ResolveSpec Components
The configuration system integrates seamlessly with ResolveSpec components:
```go
cfg, _ := config.NewManager().Load().GetConfig()
// Server
srv := server.NewGracefulServer(server.Config{
Addr: cfg.Server.Addr,
ShutdownTimeout: cfg.Server.ShutdownTimeout,
// ... other fields
})
// Tracing
if cfg.Tracing.Enabled {
tracer := tracing.Init(tracing.Config{
ServiceName: cfg.Tracing.ServiceName,
ServiceVersion: cfg.Tracing.ServiceVersion,
Endpoint: cfg.Tracing.Endpoint,
})
defer tracer.Shutdown(context.Background())
}
// Cache
var cacheProvider cache.Provider
switch cfg.Cache.Provider {
case "redis":
cacheProvider = cache.NewRedisProvider(cfg.Cache.Redis.Host, cfg.Cache.Redis.Port, ...)
case "memcache":
cacheProvider = cache.NewMemcacheProvider(cfg.Cache.Memcache.Servers, ...)
default:
cacheProvider = cache.NewMemoryProvider()
}
// Logger
logger.Init(cfg.Logger.Dev)
if cfg.Logger.Path != "" {
logger.UpdateLoggerPath(cfg.Logger.Path, cfg.Logger.Dev)
}
```

143
pkg/config/config.go Normal file
View File

@ -0,0 +1,143 @@
package config
import "time"
// Config represents the complete application configuration
type Config struct {
Server ServerConfig `mapstructure:"server"`
Tracing TracingConfig `mapstructure:"tracing"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
ErrorTracking ErrorTrackingConfig `mapstructure:"error_tracking"`
Middleware MiddlewareConfig `mapstructure:"middleware"`
CORS CORSConfig `mapstructure:"cors"`
Database DatabaseConfig `mapstructure:"database"`
EventBroker EventBrokerConfig `mapstructure:"event_broker"`
}
// ServerConfig holds server-related configuration
type ServerConfig struct {
Addr string `mapstructure:"addr"`
ShutdownTimeout time.Duration `mapstructure:"shutdown_timeout"`
DrainTimeout time.Duration `mapstructure:"drain_timeout"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
}
// TracingConfig holds OpenTelemetry tracing configuration
type TracingConfig struct {
Enabled bool `mapstructure:"enabled"`
ServiceName string `mapstructure:"service_name"`
ServiceVersion string `mapstructure:"service_version"`
Endpoint string `mapstructure:"endpoint"`
}
// CacheConfig holds cache provider configuration
type CacheConfig struct {
Provider string `mapstructure:"provider"` // memory, redis, memcache
Redis RedisConfig `mapstructure:"redis"`
Memcache MemcacheConfig `mapstructure:"memcache"`
}
// RedisConfig holds Redis-specific configuration
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
// MemcacheConfig holds Memcache-specific configuration
type MemcacheConfig struct {
Servers []string `mapstructure:"servers"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
Timeout time.Duration `mapstructure:"timeout"`
}
// LoggerConfig holds logger configuration
type LoggerConfig struct {
Dev bool `mapstructure:"dev"`
Path string `mapstructure:"path"`
}
// MiddlewareConfig holds middleware configuration
type MiddlewareConfig struct {
RateLimitRPS float64 `mapstructure:"rate_limit_rps"`
RateLimitBurst int `mapstructure:"rate_limit_burst"`
MaxRequestSize int64 `mapstructure:"max_request_size"`
}
// CORSConfig holds CORS configuration
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowedMethods []string `mapstructure:"allowed_methods"`
AllowedHeaders []string `mapstructure:"allowed_headers"`
MaxAge int `mapstructure:"max_age"`
}
// DatabaseConfig holds database configuration (primarily for testing)
type DatabaseConfig struct {
URL string `mapstructure:"url"`
}
// ErrorTrackingConfig holds error tracking configuration
type ErrorTrackingConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // sentry, noop
DSN string `mapstructure:"dsn"` // Sentry DSN
Environment string `mapstructure:"environment"` // e.g., production, staging, development
Release string `mapstructure:"release"` // Application version/release
Debug bool `mapstructure:"debug"` // Enable debug mode
SampleRate float64 `mapstructure:"sample_rate"` // Error sample rate (0.0-1.0)
TracesSampleRate float64 `mapstructure:"traces_sample_rate"` // Traces sample rate (0.0-1.0)
}
// EventBrokerConfig contains configuration for the event broker
type EventBrokerConfig struct {
Enabled bool `mapstructure:"enabled"`
Provider string `mapstructure:"provider"` // memory, redis, nats, database
Mode string `mapstructure:"mode"` // sync, async
WorkerCount int `mapstructure:"worker_count"`
BufferSize int `mapstructure:"buffer_size"`
InstanceID string `mapstructure:"instance_id"`
Redis EventBrokerRedisConfig `mapstructure:"redis"`
NATS EventBrokerNATSConfig `mapstructure:"nats"`
Database EventBrokerDatabaseConfig `mapstructure:"database"`
RetryPolicy EventBrokerRetryPolicyConfig `mapstructure:"retry_policy"`
}
// EventBrokerRedisConfig contains Redis-specific configuration
type EventBrokerRedisConfig struct {
StreamName string `mapstructure:"stream_name"`
ConsumerGroup string `mapstructure:"consumer_group"`
MaxLen int64 `mapstructure:"max_len"`
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
}
// EventBrokerNATSConfig contains NATS-specific configuration
type EventBrokerNATSConfig struct {
URL string `mapstructure:"url"`
StreamName string `mapstructure:"stream_name"`
Subjects []string `mapstructure:"subjects"`
Storage string `mapstructure:"storage"` // file, memory
MaxAge time.Duration `mapstructure:"max_age"`
}
// EventBrokerDatabaseConfig contains database provider configuration
type EventBrokerDatabaseConfig struct {
TableName string `mapstructure:"table_name"`
Channel string `mapstructure:"channel"` // PostgreSQL NOTIFY channel name
PollInterval time.Duration `mapstructure:"poll_interval"`
}
// EventBrokerRetryPolicyConfig contains retry policy configuration
type EventBrokerRetryPolicyConfig struct {
MaxRetries int `mapstructure:"max_retries"`
InitialDelay time.Duration `mapstructure:"initial_delay"`
MaxDelay time.Duration `mapstructure:"max_delay"`
BackoffFactor float64 `mapstructure:"backoff_factor"`
}

203
pkg/config/manager.go Normal file
View File

@ -0,0 +1,203 @@
package config
import (
"fmt"
"strings"
"github.com/spf13/viper"
)
// Manager handles configuration loading from multiple sources
type Manager struct {
v *viper.Viper
}
// NewManager creates a new configuration manager with defaults
func NewManager() *Manager {
v := viper.New()
// Set configuration file settings
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath(".")
v.AddConfigPath("./config")
v.AddConfigPath("/etc/resolvespec")
v.AddConfigPath("$HOME/.resolvespec")
// Enable environment variable support
v.SetEnvPrefix("RESOLVESPEC")
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.AutomaticEnv()
// Set default values
setDefaults(v)
return &Manager{v: v}
}
// NewManagerWithOptions creates a new configuration manager with custom options
func NewManagerWithOptions(opts ...Option) *Manager {
m := NewManager()
for _, opt := range opts {
opt(m)
}
return m
}
// Option is a functional option for configuring the Manager
type Option func(*Manager)
// WithConfigFile sets a specific config file path
func WithConfigFile(path string) Option {
return func(m *Manager) {
m.v.SetConfigFile(path)
}
}
// WithConfigName sets the config file name (without extension)
func WithConfigName(name string) Option {
return func(m *Manager) {
m.v.SetConfigName(name)
}
}
// WithConfigPath adds a path to search for config files
func WithConfigPath(path string) Option {
return func(m *Manager) {
m.v.AddConfigPath(path)
}
}
// WithEnvPrefix sets the environment variable prefix
func WithEnvPrefix(prefix string) Option {
return func(m *Manager) {
m.v.SetEnvPrefix(prefix)
}
}
// Load attempts to load configuration from file and environment
func (m *Manager) Load() error {
// Try to read config file (not an error if it doesn't exist)
if err := m.v.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return fmt.Errorf("error reading config file: %w", err)
}
// Config file not found; will rely on defaults and env vars
}
return nil
}
// GetConfig returns the complete configuration
func (m *Manager) GetConfig() (*Config, error) {
var cfg Config
if err := m.v.Unmarshal(&cfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
return &cfg, nil
}
// Get returns a configuration value by key
func (m *Manager) Get(key string) interface{} {
return m.v.Get(key)
}
// GetString returns a string configuration value
func (m *Manager) GetString(key string) string {
return m.v.GetString(key)
}
// GetInt returns an int configuration value
func (m *Manager) GetInt(key string) int {
return m.v.GetInt(key)
}
// GetBool returns a bool configuration value
func (m *Manager) GetBool(key string) bool {
return m.v.GetBool(key)
}
// Set sets a configuration value
func (m *Manager) Set(key string, value interface{}) {
m.v.Set(key, value)
}
// setDefaults sets default configuration values
func setDefaults(v *viper.Viper) {
// Server defaults
v.SetDefault("server.addr", ":8080")
v.SetDefault("server.shutdown_timeout", "30s")
v.SetDefault("server.drain_timeout", "25s")
v.SetDefault("server.read_timeout", "10s")
v.SetDefault("server.write_timeout", "10s")
v.SetDefault("server.idle_timeout", "120s")
// Tracing defaults
v.SetDefault("tracing.enabled", false)
v.SetDefault("tracing.service_name", "resolvespec")
v.SetDefault("tracing.service_version", "1.0.0")
v.SetDefault("tracing.endpoint", "")
// Cache defaults
v.SetDefault("cache.provider", "memory")
v.SetDefault("cache.redis.host", "localhost")
v.SetDefault("cache.redis.port", 6379)
v.SetDefault("cache.redis.password", "")
v.SetDefault("cache.redis.db", 0)
v.SetDefault("cache.memcache.servers", []string{"localhost:11211"})
v.SetDefault("cache.memcache.max_idle_conns", 10)
v.SetDefault("cache.memcache.timeout", "100ms")
// Logger defaults
v.SetDefault("logger.dev", false)
v.SetDefault("logger.path", "")
// Middleware defaults
v.SetDefault("middleware.rate_limit_rps", 100.0)
v.SetDefault("middleware.rate_limit_burst", 200)
v.SetDefault("middleware.max_request_size", 10485760) // 10MB
// CORS defaults
v.SetDefault("cors.allowed_origins", []string{"*"})
v.SetDefault("cors.allowed_methods", []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"})
v.SetDefault("cors.allowed_headers", []string{"*"})
v.SetDefault("cors.max_age", 3600)
// Database defaults
v.SetDefault("database.url", "")
// Event Broker defaults
v.SetDefault("event_broker.enabled", false)
v.SetDefault("event_broker.provider", "memory")
v.SetDefault("event_broker.mode", "async")
v.SetDefault("event_broker.worker_count", 10)
v.SetDefault("event_broker.buffer_size", 1000)
v.SetDefault("event_broker.instance_id", "")
// Event Broker - Redis defaults
v.SetDefault("event_broker.redis.stream_name", "resolvespec:events")
v.SetDefault("event_broker.redis.consumer_group", "resolvespec-workers")
v.SetDefault("event_broker.redis.max_len", 10000)
v.SetDefault("event_broker.redis.host", "localhost")
v.SetDefault("event_broker.redis.port", 6379)
v.SetDefault("event_broker.redis.password", "")
v.SetDefault("event_broker.redis.db", 0)
// Event Broker - NATS defaults
v.SetDefault("event_broker.nats.url", "nats://localhost:4222")
v.SetDefault("event_broker.nats.stream_name", "RESOLVESPEC_EVENTS")
v.SetDefault("event_broker.nats.subjects", []string{"events.>"})
v.SetDefault("event_broker.nats.storage", "file")
v.SetDefault("event_broker.nats.max_age", "24h")
// Event Broker - Database defaults
v.SetDefault("event_broker.database.table_name", "events")
v.SetDefault("event_broker.database.channel", "resolvespec_events")
v.SetDefault("event_broker.database.poll_interval", "1s")
// Event Broker - Retry Policy defaults
v.SetDefault("event_broker.retry_policy.max_retries", 3)
v.SetDefault("event_broker.retry_policy.initial_delay", "1s")
v.SetDefault("event_broker.retry_policy.max_delay", "30s")
v.SetDefault("event_broker.retry_policy.backoff_factor", 2.0)
}

166
pkg/config/manager_test.go Normal file
View File

@ -0,0 +1,166 @@
package config
import (
"os"
"testing"
"time"
)
func TestNewManager(t *testing.T) {
mgr := NewManager()
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
if mgr.v == nil {
t.Fatal("Expected viper instance to be non-nil")
}
}
func TestDefaultValues(t *testing.T) {
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test default values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":8080"},
{"server.shutdown_timeout", cfg.Server.ShutdownTimeout, 30 * time.Second},
{"tracing.enabled", cfg.Tracing.Enabled, false},
{"tracing.service_name", cfg.Tracing.ServiceName, "resolvespec"},
{"cache.provider", cfg.Cache.Provider, "memory"},
{"cache.redis.host", cfg.Cache.Redis.Host, "localhost"},
{"cache.redis.port", cfg.Cache.Redis.Port, 6379},
{"logger.dev", cfg.Logger.Dev, false},
{"middleware.rate_limit_rps", cfg.Middleware.RateLimitRPS, 100.0},
{"middleware.rate_limit_burst", cfg.Middleware.RateLimitBurst, 200},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestEnvironmentVariableOverrides(t *testing.T) {
// Set environment variables
os.Setenv("RESOLVESPEC_SERVER_ADDR", ":9090")
os.Setenv("RESOLVESPEC_TRACING_ENABLED", "true")
os.Setenv("RESOLVESPEC_CACHE_PROVIDER", "redis")
os.Setenv("RESOLVESPEC_LOGGER_DEV", "true")
defer func() {
os.Unsetenv("RESOLVESPEC_SERVER_ADDR")
os.Unsetenv("RESOLVESPEC_TRACING_ENABLED")
os.Unsetenv("RESOLVESPEC_CACHE_PROVIDER")
os.Unsetenv("RESOLVESPEC_LOGGER_DEV")
}()
mgr := NewManager()
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
// Test environment variable overrides
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"server.addr", cfg.Server.Addr, ":9090"},
{"tracing.enabled", cfg.Tracing.Enabled, true},
{"cache.provider", cfg.Cache.Provider, "redis"},
{"logger.dev", cfg.Logger.Dev, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.got != tt.expected {
t.Errorf("%s: got %v, want %v", tt.name, tt.got, tt.expected)
}
})
}
}
func TestProgrammaticConfiguration(t *testing.T) {
mgr := NewManager()
mgr.Set("server.addr", ":7070")
mgr.Set("tracing.service_name", "test-service")
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":7070" {
t.Errorf("server.addr: got %s, want :7070", cfg.Server.Addr)
}
if cfg.Tracing.ServiceName != "test-service" {
t.Errorf("tracing.service_name: got %s, want test-service", cfg.Tracing.ServiceName)
}
}
func TestGetterMethods(t *testing.T) {
mgr := NewManager()
mgr.Set("test.string", "value")
mgr.Set("test.int", 42)
mgr.Set("test.bool", true)
if got := mgr.GetString("test.string"); got != "value" {
t.Errorf("GetString: got %s, want value", got)
}
if got := mgr.GetInt("test.int"); got != 42 {
t.Errorf("GetInt: got %d, want 42", got)
}
if got := mgr.GetBool("test.bool"); !got {
t.Errorf("GetBool: got %v, want true", got)
}
}
func TestWithOptions(t *testing.T) {
mgr := NewManagerWithOptions(
WithEnvPrefix("MYAPP"),
WithConfigName("myconfig"),
)
if mgr == nil {
t.Fatal("Expected manager to be non-nil")
}
// Set environment variable with custom prefix
os.Setenv("MYAPP_SERVER_ADDR", ":5000")
defer os.Unsetenv("MYAPP_SERVER_ADDR")
if err := mgr.Load(); err != nil {
t.Fatalf("Failed to load config: %v", err)
}
cfg, err := mgr.GetConfig()
if err != nil {
t.Fatalf("Failed to get config: %v", err)
}
if cfg.Server.Addr != ":5000" {
t.Errorf("server.addr: got %s, want :5000", cfg.Server.Addr)
}
}

150
pkg/errortracking/README.md Normal file
View File

@ -0,0 +1,150 @@
# Error Tracking
This package provides error tracking integration for ResolveSpec, with built-in support for Sentry.
## Features
- **Provider Interface**: Flexible design supporting multiple error tracking backends
- **Sentry Integration**: Full-featured Sentry support with automatic error, warning, and panic tracking
- **Automatic Logger Integration**: All `logger.Error()` and `logger.Warn()` calls are automatically sent to the error tracker
- **Panic Tracking**: Automatic panic capture with stack traces
- **NoOp Provider**: Zero-overhead when error tracking is disabled
## Configuration
Add error tracking configuration to your config file:
```yaml
error_tracking:
enabled: true
provider: "sentry" # Currently supports: "sentry" or "noop"
dsn: "https://your-sentry-dsn@sentry.io/project-id"
environment: "production" # e.g., production, staging, development
release: "v1.0.0" # Your application version
debug: false
sample_rate: 1.0 # Error sample rate (0.0-1.0)
traces_sample_rate: 0.1 # Traces sample rate (0.0-1.0)
```
## Usage
### Initialization
Initialize error tracking in your application startup:
```go
package main
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func main() {
// Load your configuration
cfg := config.Config{
ErrorTracking: config.ErrorTrackingConfig{
Enabled: true,
Provider: "sentry",
DSN: "https://your-sentry-dsn@sentry.io/project-id",
Environment: "production",
Release: "v1.0.0",
SampleRate: 1.0,
},
}
// Initialize logger
logger.Init(false)
// Initialize error tracking
provider, err := errortracking.NewProviderFromConfig(cfg.ErrorTracking)
if err != nil {
logger.Error("Failed to initialize error tracking: %v", err)
} else {
logger.InitErrorTracking(provider)
}
// Your application code...
// Cleanup on shutdown
defer logger.CloseErrorTracking()
}
```
### Automatic Tracking
Once initialized, all logger errors and warnings are automatically sent to the error tracker:
```go
// This will be logged AND sent to Sentry
logger.Error("Database connection failed: %v", err)
// This will also be logged AND sent to Sentry
logger.Warn("Cache miss for key: %s", key)
```
### Panic Tracking
Panics are automatically captured when using the logger's panic handlers:
```go
// Using CatchPanic
defer logger.CatchPanic("MyFunction")
// Using CatchPanicCallback
defer logger.CatchPanicCallback("MyFunction", func(err any) {
// Custom cleanup
})
// Using HandlePanic
defer func() {
if r := recover(); r != nil {
err = logger.HandlePanic("MyMethod", r)
}
}()
```
### Manual Tracking
You can also use the provider directly for custom error tracking:
```go
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/errortracking"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
func someFunction() {
tracker := logger.GetErrorTracker()
if tracker != nil {
// Capture an error
tracker.CaptureError(context.Background(), err, errortracking.SeverityError, map[string]interface{}{
"user_id": userID,
"request_id": requestID,
})
// Capture a message
tracker.CaptureMessage(context.Background(), "Important event occurred", errortracking.SeverityInfo, map[string]interface{}{
"event_type": "user_signup",
})
// Capture a panic
tracker.CapturePanic(context.Background(), recovered, stackTrace, map[string]interface{}{
"context": "background_job",
})
}
}
```
## Severity Levels
The package supports the following severity levels:
- `SeverityError`: For errors that should be tracked and investigated
- `SeverityWarning`: For warnings that may indicate potential issues
- `SeverityInfo`: For informational messages
- `SeverityDebug`: For debug-level information
```

View File

@ -0,0 +1,67 @@
package errortracking
import (
"context"
"errors"
"testing"
)
func TestNoOpProvider(t *testing.T) {
provider := NewNoOpProvider()
// Test that all methods can be called without panicking
t.Run("CaptureError", func(t *testing.T) {
provider.CaptureError(context.Background(), errors.New("test error"), SeverityError, nil)
})
t.Run("CaptureMessage", func(t *testing.T) {
provider.CaptureMessage(context.Background(), "test message", SeverityWarning, nil)
})
t.Run("CapturePanic", func(t *testing.T) {
provider.CapturePanic(context.Background(), "panic!", []byte("stack trace"), nil)
})
t.Run("Flush", func(t *testing.T) {
result := provider.Flush(5)
if !result {
t.Error("Expected Flush to return true")
}
})
t.Run("Close", func(t *testing.T) {
err := provider.Close()
if err != nil {
t.Errorf("Expected Close to return nil, got %v", err)
}
})
}
func TestSeverityLevels(t *testing.T) {
tests := []struct {
name string
severity Severity
expected string
}{
{"Error", SeverityError, "error"},
{"Warning", SeverityWarning, "warning"},
{"Info", SeverityInfo, "info"},
{"Debug", SeverityDebug, "debug"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if string(tt.severity) != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, string(tt.severity))
}
})
}
}
func TestProviderInterface(t *testing.T) {
// Test that NoOpProvider implements Provider interface
var _ Provider = (*NoOpProvider)(nil)
// Test that SentryProvider implements Provider interface
var _ Provider = (*SentryProvider)(nil)
}

View File

@ -0,0 +1,33 @@
package errortracking
import (
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates an error tracking provider based on the configuration
func NewProviderFromConfig(cfg config.ErrorTrackingConfig) (Provider, error) {
if !cfg.Enabled {
return NewNoOpProvider(), nil
}
switch cfg.Provider {
case "sentry":
if cfg.DSN == "" {
return nil, fmt.Errorf("sentry DSN is required when error tracking is enabled")
}
return NewSentryProvider(SentryConfig{
DSN: cfg.DSN,
Environment: cfg.Environment,
Release: cfg.Release,
Debug: cfg.Debug,
SampleRate: cfg.SampleRate,
TracesSampleRate: cfg.TracesSampleRate,
})
case "noop", "":
return NewNoOpProvider(), nil
default:
return nil, fmt.Errorf("unknown error tracking provider: %s", cfg.Provider)
}
}

View File

@ -0,0 +1,33 @@
package errortracking
import (
"context"
)
// Severity represents the severity level of an error
type Severity string
const (
SeverityError Severity = "error"
SeverityWarning Severity = "warning"
SeverityInfo Severity = "info"
SeverityDebug Severity = "debug"
)
// Provider defines the interface for error tracking providers
type Provider interface {
// CaptureError captures an error with the given severity and additional context
CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{})
// CaptureMessage captures a message with the given severity and additional context
CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{})
// CapturePanic captures a panic with stack trace
CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{})
// Flush waits for all events to be sent (useful for graceful shutdown)
Flush(timeout int) bool
// Close closes the provider and releases resources
Close() error
}

37
pkg/errortracking/noop.go Normal file
View File

@ -0,0 +1,37 @@
package errortracking
import "context"
// NoOpProvider is a no-op implementation of the Provider interface
// Used when error tracking is disabled
type NoOpProvider struct{}
// NewNoOpProvider creates a new NoOp provider
func NewNoOpProvider() *NoOpProvider {
return &NoOpProvider{}
}
// CaptureError does nothing
func (n *NoOpProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
// No-op
}
// CaptureMessage does nothing
func (n *NoOpProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
// No-op
}
// CapturePanic does nothing
func (n *NoOpProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
// No-op
}
// Flush does nothing and returns true
func (n *NoOpProvider) Flush(timeout int) bool {
return true
}
// Close does nothing
func (n *NoOpProvider) Close() error {
return nil
}

154
pkg/errortracking/sentry.go Normal file
View File

@ -0,0 +1,154 @@
package errortracking
import (
"context"
"fmt"
"time"
"github.com/getsentry/sentry-go"
)
// SentryProvider implements the Provider interface using Sentry
type SentryProvider struct {
hub *sentry.Hub
}
// SentryConfig holds the configuration for Sentry
type SentryConfig struct {
DSN string
Environment string
Release string
Debug bool
SampleRate float64
TracesSampleRate float64
}
// NewSentryProvider creates a new Sentry provider
func NewSentryProvider(config SentryConfig) (*SentryProvider, error) {
err := sentry.Init(sentry.ClientOptions{
Dsn: config.DSN,
Environment: config.Environment,
Release: config.Release,
Debug: config.Debug,
AttachStacktrace: true,
SampleRate: config.SampleRate,
TracesSampleRate: config.TracesSampleRate,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize Sentry: %w", err)
}
return &SentryProvider{
hub: sentry.CurrentHub(),
}, nil
}
// CaptureError captures an error with the given severity and additional context
func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity Severity, extra map[string]interface{}) {
if err == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = err.Error()
event.Exception = []sentry.Exception{
{
Value: err.Error(),
Type: fmt.Sprintf("%T", err),
Stacktrace: sentry.ExtractStacktrace(err),
},
}
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CaptureMessage captures a message with the given severity and additional context
func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, severity Severity, extra map[string]interface{}) {
if message == "" {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = s.convertSeverity(severity)
event.Message = message
if extra != nil {
event.Extra = extra
}
hub.CaptureEvent(event)
}
// CapturePanic captures a panic with stack trace
func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}, stackTrace []byte, extra map[string]interface{}) {
if recovered == nil {
return
}
hub := sentry.GetHubFromContext(ctx)
if hub == nil {
hub = s.hub
}
event := sentry.NewEvent()
event.Level = sentry.LevelError
event.Message = fmt.Sprintf("Panic: %v", recovered)
event.Exception = []sentry.Exception{
{
Value: fmt.Sprintf("%v", recovered),
Type: "panic",
},
}
if extra != nil {
event.Extra = extra
}
if stackTrace != nil {
event.Extra["stack_trace"] = string(stackTrace)
}
hub.CaptureEvent(event)
}
// Flush waits for all events to be sent (useful for graceful shutdown)
func (s *SentryProvider) Flush(timeout int) bool {
return sentry.Flush(time.Duration(timeout) * time.Second)
}
// Close closes the provider and releases resources
func (s *SentryProvider) Close() error {
sentry.Flush(2 * time.Second)
return nil
}
// convertSeverity converts our Severity to Sentry's Level
func (s *SentryProvider) convertSeverity(severity Severity) sentry.Level {
switch severity {
case SeverityError:
return sentry.LevelError
case SeverityWarning:
return sentry.LevelWarning
case SeverityInfo:
return sentry.LevelInfo
case SeverityDebug:
return sentry.LevelDebug
default:
return sentry.LevelError
}
}

327
pkg/eventbroker/README.md Normal file
View File

@ -0,0 +1,327 @@
# Event Broker System
A comprehensive event handler/broker system for ResolveSpec that provides real-time event publishing, subscription, and cross-instance communication.
## Features
- **Multiple Sources**: Events from database, websockets, frontend, system, and internal sources
- **Event Status Tracking**: Pending, processing, completed, failed states with timestamps
- **Rich Metadata**: User IDs, session IDs, instance IDs, JSON payloads, and custom metadata
- **Sync & Async Modes**: Choose between synchronous or asynchronous event processing
- **Pattern Matching**: Subscribe to events using glob-style patterns
- **Multiple Providers**: In-memory, Redis Streams, NATS JetStream, PostgreSQL with NOTIFY
- **Hook Integration**: Automatic CRUD event capture via restheadspec hooks
- **Retry Logic**: Configurable retry policy with exponential backoff
- **Metrics**: Prometheus-compatible metrics for monitoring
- **Graceful Shutdown**: Proper cleanup and event flushing on shutdown
## Quick Start
### 1. Configuration
Add to your `config.yaml`:
```yaml
event_broker:
enabled: true
provider: memory # memory, redis, nats, database
mode: async # sync, async
worker_count: 10
buffer_size: 1000
instance_id: "${HOSTNAME}"
```
### 2. Initialize
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
)
func main() {
// Load configuration
cfgMgr := config.NewManager()
cfg, _ := cfgMgr.GetConfig()
// Initialize event broker
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
log.Fatal(err)
}
}
```
### 3. Subscribe to Events
```go
// Subscribe to specific events
eventbroker.Subscribe("public.users.create", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
log.Printf("New user created: %s", event.Payload)
// Send welcome email, update cache, etc.
return nil
},
))
// Subscribe with patterns
eventbroker.Subscribe("*.*.delete", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
log.Printf("Deleted: %s.%s", event.Schema, event.Entity)
return nil
},
))
```
### 4. Publish Events
```go
// Create and publish an event
event := eventbroker.NewEvent(eventbroker.EventSourceDatabase, "public.users.update")
event.InstanceID = eventbroker.GetDefaultBroker().InstanceID()
event.UserID = 123
event.SessionID = "session-456"
event.Schema = "public"
event.Entity = "users"
event.Operation = "update"
event.SetPayload(map[string]interface{}{
"id": 123,
"name": "John Doe",
})
// Async (non-blocking)
eventbroker.PublishAsync(ctx, event)
// Sync (blocking)
eventbroker.PublishSync(ctx, event)
```
## Automatic CRUD Event Capture
Automatically capture database CRUD operations:
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/eventbroker"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
func setupHooks(handler *restheadspec.Handler) {
broker := eventbroker.GetDefaultBroker()
// Configure which operations to capture
config := eventbroker.DefaultCRUDHookConfig()
config.EnableRead = false // Disable read events for performance
// Register hooks
eventbroker.RegisterCRUDHooks(broker, handler.Hooks(), config)
// Now all create/update/delete operations automatically publish events!
}
```
## Event Structure
Every event contains:
```go
type Event struct {
ID string // UUID
Source EventSource // database, websocket, system, frontend, internal
Type string // Pattern: schema.entity.operation
Status EventStatus // pending, processing, completed, failed
Payload json.RawMessage // JSON payload
UserID int // User who triggered the event
SessionID string // Session identifier
InstanceID string // Server instance identifier
Schema string // Database schema
Entity string // Database entity/table
Operation string // create, update, delete, read
CreatedAt time.Time // When event was created
ProcessedAt *time.Time // When processing started
CompletedAt *time.Time // When processing completed
Error string // Error message if failed
Metadata map[string]interface{} // Additional context
RetryCount int // Number of retry attempts
}
```
## Pattern Matching
Subscribe to events using glob-style patterns:
| Pattern | Matches | Example |
|---------|---------|---------|
| `*` | All events | Any event |
| `public.users.*` | All user operations | `public.users.create`, `public.users.update` |
| `*.*.create` | All create operations | `public.users.create`, `auth.sessions.create` |
| `public.*.*` | All events in public schema | `public.users.create`, `public.posts.delete` |
| `public.users.create` | Exact match | Only `public.users.create` |
## Providers
### Memory Provider (Default)
Best for: Development, single-instance deployments
- **Pros**: Fast, no dependencies, simple
- **Cons**: Events lost on restart, single-instance only
```yaml
event_broker:
provider: memory
```
### Redis Provider (Future)
Best for: Production, multi-instance deployments
- **Pros**: Persistent, cross-instance pub/sub, reliable
- **Cons**: Requires Redis
```yaml
event_broker:
provider: redis
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
host: "localhost"
port: 6379
```
### NATS Provider (Future)
Best for: High-performance, low-latency requirements
- **Pros**: Very fast, built-in clustering, durable
- **Cons**: Requires NATS server
```yaml
event_broker:
provider: nats
nats:
url: "nats://localhost:4222"
stream_name: "RESOLVESPEC_EVENTS"
```
### Database Provider (Future)
Best for: Audit trails, event replay, SQL queries
- **Pros**: No additional infrastructure, full SQL query support, PostgreSQL NOTIFY for real-time
- **Cons**: Slower than Redis/NATS
```yaml
event_broker:
provider: database
database:
table_name: "events"
channel: "resolvespec_events"
```
## Processing Modes
### Async Mode (Recommended)
Events are queued and processed by worker pool:
- Non-blocking event publishing
- Configurable worker count
- Better throughput
- Events may be processed out of order
```yaml
event_broker:
mode: async
worker_count: 10
buffer_size: 1000
```
### Sync Mode
Events are processed immediately:
- Blocking event publishing
- Guaranteed ordering
- Immediate error feedback
- Lower throughput
```yaml
event_broker:
mode: sync
```
## Retry Policy
Configure automatic retries for failed handlers:
```yaml
event_broker:
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 30s
backoff_factor: 2.0 # Exponential backoff
```
## Metrics
The event broker exposes Prometheus metrics:
- `eventbroker_events_published_total{source, type}` - Total events published
- `eventbroker_events_processed_total{source, type, status}` - Total events processed
- `eventbroker_event_processing_duration_seconds{source, type}` - Event processing duration
- `eventbroker_queue_size` - Current queue size (async mode)
## Best Practices
1. **Use Async Mode**: For better performance, use async mode in production
2. **Disable Read Events**: Read events can be high volume; disable if not needed
3. **Pattern Matching**: Use specific patterns to avoid processing unnecessary events
4. **Error Handling**: Always handle errors in event handlers; they won't fail the original operation
5. **Idempotency**: Make handlers idempotent as events may be retried
6. **Payload Size**: Keep payloads reasonable; avoid large objects
7. **Monitoring**: Monitor metrics to detect issues early
## Examples
See `example_usage.go` for comprehensive examples including:
- Basic event publishing and subscription
- Hook integration
- Error handling
- Configuration
- Pattern matching
## Architecture
```
┌─────────────────┐
│ Application │
└────────┬────────┘
├─ Publish Events
┌────────▼────────┐ ┌──────────────┐
│ Event Broker │◄────►│ Subscribers │
└────────┬────────┘ └──────────────┘
├─ Store Events
┌────────▼────────┐
│ Provider │
│ (Memory/Redis │
│ /NATS/DB) │
└─────────────────┘
```
## Future Enhancements
- [ ] Database Provider with PostgreSQL NOTIFY
- [ ] Redis Streams Provider
- [ ] NATS JetStream Provider
- [ ] Event replay functionality
- [ ] Dead letter queue
- [ ] Event filtering at provider level
- [ ] Batch publishing
- [ ] Event compression
- [ ] Schema versioning

453
pkg/eventbroker/broker.go Normal file
View File

@ -0,0 +1,453 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Broker is the main interface for event publishing and subscription
type Broker interface {
// Publish publishes an event (mode-dependent: sync or async)
Publish(ctx context.Context, event *Event) error
// PublishSync publishes an event synchronously (blocks until all handlers complete)
PublishSync(ctx context.Context, event *Event) error
// PublishAsync publishes an event asynchronously (returns immediately)
PublishAsync(ctx context.Context, event *Event) error
// Subscribe registers a handler for events matching the pattern
Subscribe(pattern string, handler EventHandler) (SubscriptionID, error)
// Unsubscribe removes a subscription
Unsubscribe(id SubscriptionID) error
// Start starts the broker (begins processing events)
Start(ctx context.Context) error
// Stop stops the broker gracefully (flushes pending events)
Stop(ctx context.Context) error
// Stats returns broker statistics
Stats(ctx context.Context) (*BrokerStats, error)
// InstanceID returns the instance ID of this broker
InstanceID() string
}
// ProcessingMode determines how events are processed
type ProcessingMode string
const (
ProcessingModeSync ProcessingMode = "sync"
ProcessingModeAsync ProcessingMode = "async"
)
// BrokerStats contains broker statistics
type BrokerStats struct {
InstanceID string `json:"instance_id"`
Mode ProcessingMode `json:"mode"`
IsRunning bool `json:"is_running"`
TotalPublished int64 `json:"total_published"`
TotalProcessed int64 `json:"total_processed"`
TotalFailed int64 `json:"total_failed"`
ActiveSubscribers int `json:"active_subscribers"`
QueueSize int `json:"queue_size,omitempty"` // For async mode
ActiveWorkers int `json:"active_workers,omitempty"` // For async mode
ProviderStats *ProviderStats `json:"provider_stats,omitempty"`
AdditionalStats map[string]interface{} `json:"additional_stats,omitempty"`
}
// EventBroker implements the Broker interface
type EventBroker struct {
provider Provider
subscriptions *subscriptionManager
mode ProcessingMode
instanceID string
retryPolicy *RetryPolicy
// Async mode fields (initialized in Phase 4)
workerPool *workerPool
// Runtime state
isRunning atomic.Bool
stopOnce sync.Once
stopCh chan struct{}
wg sync.WaitGroup
// Statistics
statsPublished atomic.Int64
statsProcessed atomic.Int64
statsFailed atomic.Int64
}
// RetryPolicy defines how failed events should be retried
type RetryPolicy struct {
MaxRetries int
InitialDelay time.Duration
MaxDelay time.Duration
BackoffFactor float64
}
// DefaultRetryPolicy returns a sensible default retry policy
func DefaultRetryPolicy() *RetryPolicy {
return &RetryPolicy{
MaxRetries: 3,
InitialDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
BackoffFactor: 2.0,
}
}
// Options for creating a new broker
type Options struct {
Provider Provider
Mode ProcessingMode
WorkerCount int // For async mode
BufferSize int // For async mode
RetryPolicy *RetryPolicy
InstanceID string
}
// NewBroker creates a new event broker with the given options
func NewBroker(opts Options) (*EventBroker, error) {
if opts.Provider == nil {
return nil, fmt.Errorf("provider is required")
}
if opts.InstanceID == "" {
return nil, fmt.Errorf("instance ID is required")
}
if opts.Mode == "" {
opts.Mode = ProcessingModeAsync // Default to async
}
if opts.RetryPolicy == nil {
opts.RetryPolicy = DefaultRetryPolicy()
}
broker := &EventBroker{
provider: opts.Provider,
subscriptions: newSubscriptionManager(),
mode: opts.Mode,
instanceID: opts.InstanceID,
retryPolicy: opts.RetryPolicy,
stopCh: make(chan struct{}),
}
// Worker pool will be initialized in Phase 4 for async mode
if opts.Mode == ProcessingModeAsync {
if opts.WorkerCount == 0 {
opts.WorkerCount = 10 // Default
}
if opts.BufferSize == 0 {
opts.BufferSize = 1000 // Default
}
broker.workerPool = newWorkerPool(opts.WorkerCount, opts.BufferSize, broker.processEvent)
}
return broker, nil
}
// Functional option pattern helpers
func WithProvider(p Provider) func(*Options) {
return func(o *Options) { o.Provider = p }
}
func WithMode(m ProcessingMode) func(*Options) {
return func(o *Options) { o.Mode = m }
}
func WithWorkerCount(count int) func(*Options) {
return func(o *Options) { o.WorkerCount = count }
}
func WithBufferSize(size int) func(*Options) {
return func(o *Options) { o.BufferSize = size }
}
func WithRetryPolicy(policy *RetryPolicy) func(*Options) {
return func(o *Options) { o.RetryPolicy = policy }
}
func WithInstanceID(id string) func(*Options) {
return func(o *Options) { o.InstanceID = id }
}
// Start starts the broker
func (b *EventBroker) Start(ctx context.Context) error {
if b.isRunning.Load() {
return fmt.Errorf("broker already running")
}
b.isRunning.Store(true)
// Start worker pool for async mode
if b.mode == ProcessingModeAsync && b.workerPool != nil {
b.workerPool.Start()
}
logger.Info("Event broker started (mode: %s, instance: %s)", b.mode, b.instanceID)
return nil
}
// Stop stops the broker gracefully
func (b *EventBroker) Stop(ctx context.Context) error {
var stopErr error
b.stopOnce.Do(func() {
logger.Info("Stopping event broker...")
// Mark as not running
b.isRunning.Store(false)
// Close the stop channel
close(b.stopCh)
// Stop worker pool for async mode
if b.mode == ProcessingModeAsync && b.workerPool != nil {
if err := b.workerPool.Stop(ctx); err != nil {
logger.Error("Error stopping worker pool: %v", err)
stopErr = err
}
}
// Wait for all goroutines
b.wg.Wait()
// Close provider
if err := b.provider.Close(); err != nil {
logger.Error("Error closing provider: %v", err)
if stopErr == nil {
stopErr = err
}
}
logger.Info("Event broker stopped")
})
return stopErr
}
// Publish publishes an event based on the broker's mode
func (b *EventBroker) Publish(ctx context.Context, event *Event) error {
if b.mode == ProcessingModeSync {
return b.PublishSync(ctx, event)
}
return b.PublishAsync(ctx, event)
}
// PublishSync publishes an event synchronously
func (b *EventBroker) PublishSync(ctx context.Context, event *Event) error {
if !b.isRunning.Load() {
return fmt.Errorf("broker is not running")
}
// Validate event
if err := event.Validate(); err != nil {
return fmt.Errorf("invalid event: %w", err)
}
// Store event in provider
if err := b.provider.Publish(ctx, event); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}
b.statsPublished.Add(1)
// Record metrics
recordEventPublished(event)
// Process event synchronously
if err := b.processEvent(ctx, event); err != nil {
logger.Error("Failed to process event %s: %v", event.ID, err)
b.statsFailed.Add(1)
return err
}
b.statsProcessed.Add(1)
return nil
}
// PublishAsync publishes an event asynchronously
func (b *EventBroker) PublishAsync(ctx context.Context, event *Event) error {
if !b.isRunning.Load() {
return fmt.Errorf("broker is not running")
}
// Validate event
if err := event.Validate(); err != nil {
return fmt.Errorf("invalid event: %w", err)
}
// Store event in provider
if err := b.provider.Publish(ctx, event); err != nil {
return fmt.Errorf("failed to publish event: %w", err)
}
b.statsPublished.Add(1)
// Record metrics
recordEventPublished(event)
// Queue for async processing
if b.mode == ProcessingModeAsync && b.workerPool != nil {
// Update queue size metrics
updateQueueSize(int64(b.workerPool.QueueSize()))
return b.workerPool.Submit(ctx, event)
}
// Fallback to sync if async not configured
return b.processEvent(ctx, event)
}
// Subscribe adds a subscription for events matching the pattern
func (b *EventBroker) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
return b.subscriptions.Subscribe(pattern, handler)
}
// Unsubscribe removes a subscription
func (b *EventBroker) Unsubscribe(id SubscriptionID) error {
return b.subscriptions.Unsubscribe(id)
}
// processEvent processes an event by calling all matching handlers
func (b *EventBroker) processEvent(ctx context.Context, event *Event) error {
startTime := time.Now()
// Get all handlers matching this event type
handlers := b.subscriptions.GetMatching(event.Type)
if len(handlers) == 0 {
logger.Debug("No handlers for event type: %s", event.Type)
return nil
}
logger.Debug("Processing event %s with %d handler(s)", event.ID, len(handlers))
// Mark event as processing
event.MarkProcessing()
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusProcessing, ""); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Execute all handlers
var lastErr error
for i, handler := range handlers {
if err := b.executeHandlerWithRetry(ctx, handler, event); err != nil {
logger.Error("Handler %d failed for event %s: %v", i+1, event.ID, err)
lastErr = err
// Continue processing other handlers
}
}
// Update final status
if lastErr != nil {
event.MarkFailed(lastErr)
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusFailed, lastErr.Error()); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Record metrics
recordEventProcessed(event, time.Since(startTime))
return lastErr
}
event.MarkCompleted()
if err := b.provider.UpdateStatus(ctx, event.ID, EventStatusCompleted, ""); err != nil {
logger.Warn("Failed to update event status: %v", err)
}
// Record metrics
recordEventProcessed(event, time.Since(startTime))
return nil
}
// executeHandlerWithRetry executes a handler with retry logic
func (b *EventBroker) executeHandlerWithRetry(ctx context.Context, handler EventHandler, event *Event) error {
var lastErr error
for attempt := 0; attempt <= b.retryPolicy.MaxRetries; attempt++ {
if attempt > 0 {
// Calculate backoff delay
delay := b.calculateBackoff(attempt)
logger.Debug("Retrying event %s (attempt %d/%d) after %v",
event.ID, attempt, b.retryPolicy.MaxRetries, delay)
select {
case <-time.After(delay):
case <-ctx.Done():
return ctx.Err()
}
event.IncrementRetry()
}
// Execute handler
if err := handler.Handle(ctx, event); err != nil {
lastErr = err
logger.Warn("Handler failed for event %s (attempt %d): %v", event.ID, attempt+1, err)
continue
}
// Success
return nil
}
return fmt.Errorf("handler failed after %d attempts: %w", b.retryPolicy.MaxRetries+1, lastErr)
}
// calculateBackoff calculates the backoff delay for a retry attempt
func (b *EventBroker) calculateBackoff(attempt int) time.Duration {
delay := float64(b.retryPolicy.InitialDelay) * pow(b.retryPolicy.BackoffFactor, float64(attempt-1))
if delay > float64(b.retryPolicy.MaxDelay) {
delay = float64(b.retryPolicy.MaxDelay)
}
return time.Duration(delay)
}
// pow is a simple integer power function
func pow(base float64, exp float64) float64 {
result := 1.0
for i := 0.0; i < exp; i++ {
result *= base
}
return result
}
// Stats returns broker statistics
func (b *EventBroker) Stats(ctx context.Context) (*BrokerStats, error) {
providerStats, err := b.provider.Stats(ctx)
if err != nil {
logger.Warn("Failed to get provider stats: %v", err)
}
stats := &BrokerStats{
InstanceID: b.instanceID,
Mode: b.mode,
IsRunning: b.isRunning.Load(),
TotalPublished: b.statsPublished.Load(),
TotalProcessed: b.statsProcessed.Load(),
TotalFailed: b.statsFailed.Load(),
ActiveSubscribers: b.subscriptions.Count(),
ProviderStats: providerStats,
}
// Add async-specific stats
if b.mode == ProcessingModeAsync && b.workerPool != nil {
stats.QueueSize = b.workerPool.QueueSize()
stats.ActiveWorkers = b.workerPool.ActiveWorkers()
}
return stats, nil
}
// InstanceID returns the instance ID
func (b *EventBroker) InstanceID() string {
return b.instanceID
}

View File

@ -0,0 +1,524 @@
package eventbroker
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestNewBroker(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 1000,
})
tests := []struct {
name string
opts Options
wantError bool
}{
{
name: "valid options",
opts: Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
},
wantError: false,
},
{
name: "missing provider",
opts: Options{
InstanceID: "test-instance",
},
wantError: true,
},
{
name: "missing instance ID",
opts: Options{
Provider: provider,
},
wantError: true,
},
{
name: "async mode with defaults",
opts: Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
},
wantError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
broker, err := NewBroker(tt.opts)
if (err != nil) != tt.wantError {
t.Errorf("NewBroker() error = %v, wantError %v", err, tt.wantError)
}
if err == nil && broker == nil {
t.Error("Expected non-nil broker")
}
})
}
}
func TestBrokerStartStop(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, err := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
if err != nil {
t.Fatalf("Failed to create broker: %v", err)
}
// Test Start
if err := broker.Start(context.Background()); err != nil {
t.Fatalf("Failed to start broker: %v", err)
}
// Test double start (should fail)
if err := broker.Start(context.Background()); err == nil {
t.Error("Expected error on double start")
}
// Test Stop
if err := broker.Stop(context.Background()); err != nil {
t.Fatalf("Failed to stop broker: %v", err)
}
// Test double stop (should not fail)
if err := broker.Stop(context.Background()); err != nil {
t.Error("Double stop should not fail")
}
}
func TestBrokerPublishSync(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe to events
called := false
var receivedEvent *Event
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
receivedEvent = event
return nil
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.PublishSync(context.Background(), event)
if err != nil {
t.Fatalf("PublishSync failed: %v", err)
}
// Verify handler was called
if !called {
t.Error("Expected handler to be called")
}
if receivedEvent == nil || receivedEvent.ID != event.ID {
t.Error("Expected to receive the published event")
}
// Verify event status
if event.Status != EventStatusCompleted {
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
}
}
func TestBrokerPublishAsync(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 2,
BufferSize: 10,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe to events
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return nil
}))
// Publish multiple events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
if err := broker.PublishAsync(context.Background(), event); err != nil {
t.Fatalf("PublishAsync failed: %v", err)
}
}
// Wait for events to be processed
time.Sleep(100 * time.Millisecond)
if callCount.Load() != 5 {
t.Errorf("Expected 5 handler calls, got %d", callCount.Load())
}
}
func TestBrokerPublishBeforeStart(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
})
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.Publish(context.Background(), event)
if err == nil {
t.Error("Expected error when publishing before start")
}
}
func TestBrokerHandlerError(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
RetryPolicy: &RetryPolicy{
MaxRetries: 2,
InitialDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
},
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe with failing handler
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return errors.New("handler error")
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
err := broker.PublishSync(context.Background(), event)
// Should fail after retries
if err == nil {
t.Error("Expected error from handler")
}
// Should have been called MaxRetries+1 times (initial + retries)
if callCount.Load() != 3 {
t.Errorf("Expected 3 calls (1 initial + 2 retries), got %d", callCount.Load())
}
// Event should be marked as failed
if event.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
}
}
func TestBrokerMultipleHandlers(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe multiple handlers
var called1, called2, called3 bool
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called1 = true
return nil
}))
broker.Subscribe("test.event", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called2 = true
return nil
}))
broker.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called3 = true
return nil
}))
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
// All handlers should be called
if !called1 || !called2 || !called3 {
t.Error("Expected all handlers to be called")
}
}
func TestBrokerUnsubscribe(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe
called := false
id, _ := broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
}))
// Unsubscribe
if err := broker.Unsubscribe(id); err != nil {
t.Fatalf("Unsubscribe failed: %v", err)
}
// Publish event
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
// Handler should not be called
if called {
t.Error("Expected handler not to be called after unsubscribe")
}
}
func TestBrokerStats(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeSync,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
// Subscribe
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
}))
// Publish events
for i := 0; i < 3; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishSync(context.Background(), event)
}
// Get stats
stats, err := broker.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.InstanceID != "test-instance" {
t.Errorf("Expected instance ID 'test-instance', got %s", stats.InstanceID)
}
if stats.TotalPublished != 3 {
t.Errorf("Expected 3 published events, got %d", stats.TotalPublished)
}
if stats.TotalProcessed != 3 {
t.Errorf("Expected 3 processed events, got %d", stats.TotalProcessed)
}
if stats.ActiveSubscribers != 1 {
t.Errorf("Expected 1 active subscriber, got %d", stats.ActiveSubscribers)
}
}
func TestBrokerInstanceID(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "my-instance",
})
if broker.InstanceID() != "my-instance" {
t.Errorf("Expected instance ID 'my-instance', got %s", broker.InstanceID())
}
}
func TestBrokerConcurrentPublish(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 5,
BufferSize: 100,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
var callCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
callCount.Add(1)
return nil
}))
// Publish concurrently
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishAsync(context.Background(), event)
}()
}
wg.Wait()
time.Sleep(200 * time.Millisecond) // Wait for async processing
if callCount.Load() != 50 {
t.Errorf("Expected 50 handler calls, got %d", callCount.Load())
}
}
func TestBrokerGracefulShutdown(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: ProcessingModeAsync,
WorkerCount: 2,
BufferSize: 10,
})
broker.Start(context.Background())
var processedCount atomic.Int32
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
time.Sleep(50 * time.Millisecond) // Simulate work
processedCount.Add(1)
return nil
}))
// Publish events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.PublishAsync(context.Background(), event)
}
// Stop broker (should wait for events to be processed)
if err := broker.Stop(context.Background()); err != nil {
t.Fatalf("Stop failed: %v", err)
}
// All events should be processed
if processedCount.Load() != 5 {
t.Errorf("Expected 5 processed events, got %d", processedCount.Load())
}
}
func TestBrokerDefaultRetryPolicy(t *testing.T) {
policy := DefaultRetryPolicy()
if policy.MaxRetries != 3 {
t.Errorf("Expected MaxRetries 3, got %d", policy.MaxRetries)
}
if policy.InitialDelay != 1*time.Second {
t.Errorf("Expected InitialDelay 1s, got %v", policy.InitialDelay)
}
if policy.BackoffFactor != 2.0 {
t.Errorf("Expected BackoffFactor 2.0, got %f", policy.BackoffFactor)
}
}
func TestBrokerProcessingModes(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
tests := []struct {
name string
mode ProcessingMode
}{
{"sync mode", ProcessingModeSync},
{"async mode", ProcessingModeAsync},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
broker, _ := NewBroker(Options{
Provider: provider,
InstanceID: "test-instance",
Mode: tt.mode,
})
broker.Start(context.Background())
defer broker.Stop(context.Background())
called := false
broker.Subscribe("test.*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
}))
event := NewEvent(EventSourceSystem, "test.event")
event.InstanceID = "test-instance"
broker.Publish(context.Background(), event)
if tt.mode == ProcessingModeAsync {
time.Sleep(50 * time.Millisecond)
}
if !called {
t.Error("Expected handler to be called")
}
})
}
}

175
pkg/eventbroker/event.go Normal file
View File

@ -0,0 +1,175 @@
package eventbroker
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
)
// EventSource represents where an event originated from
type EventSource string
const (
EventSourceDatabase EventSource = "database"
EventSourceWebSocket EventSource = "websocket"
EventSourceFrontend EventSource = "frontend"
EventSourceSystem EventSource = "system"
EventSourceInternal EventSource = "internal"
)
// EventStatus represents the current state of an event
type EventStatus string
const (
EventStatusPending EventStatus = "pending"
EventStatusProcessing EventStatus = "processing"
EventStatusCompleted EventStatus = "completed"
EventStatusFailed EventStatus = "failed"
)
// Event represents a single event in the system with complete metadata
type Event struct {
// Identification
ID string `json:"id" db:"id"`
// Source & Classification
Source EventSource `json:"source" db:"source"`
Type string `json:"type" db:"type"` // Pattern: schema.entity.operation
// Status Tracking
Status EventStatus `json:"status" db:"status"`
RetryCount int `json:"retry_count" db:"retry_count"`
Error string `json:"error,omitempty" db:"error"`
// Payload
Payload json.RawMessage `json:"payload" db:"payload"`
// Context Information
UserID int `json:"user_id" db:"user_id"`
SessionID string `json:"session_id" db:"session_id"`
InstanceID string `json:"instance_id" db:"instance_id"`
// Database Context
Schema string `json:"schema" db:"schema"`
Entity string `json:"entity" db:"entity"`
Operation string `json:"operation" db:"operation"` // create, update, delete, read
// Timestamps
CreatedAt time.Time `json:"created_at" db:"created_at"`
ProcessedAt *time.Time `json:"processed_at,omitempty" db:"processed_at"`
CompletedAt *time.Time `json:"completed_at,omitempty" db:"completed_at"`
// Extensibility
Metadata map[string]interface{} `json:"metadata" db:"metadata"`
}
// NewEvent creates a new event with defaults
func NewEvent(source EventSource, eventType string) *Event {
return &Event{
ID: uuid.New().String(),
Source: source,
Type: eventType,
Status: EventStatusPending,
CreatedAt: time.Now(),
Metadata: make(map[string]interface{}),
RetryCount: 0,
}
}
// EventType generates a type string from schema, entity, and operation
// Pattern: schema.entity.operation (e.g., "public.users.create")
func EventType(schema, entity, operation string) string {
return fmt.Sprintf("%s.%s.%s", schema, entity, operation)
}
// MarkProcessing marks the event as being processed
func (e *Event) MarkProcessing() {
e.Status = EventStatusProcessing
now := time.Now()
e.ProcessedAt = &now
}
// MarkCompleted marks the event as successfully completed
func (e *Event) MarkCompleted() {
e.Status = EventStatusCompleted
now := time.Now()
e.CompletedAt = &now
}
// MarkFailed marks the event as failed with an error message
func (e *Event) MarkFailed(err error) {
e.Status = EventStatusFailed
e.Error = err.Error()
now := time.Now()
e.CompletedAt = &now
}
// IncrementRetry increments the retry counter
func (e *Event) IncrementRetry() {
e.RetryCount++
}
// SetPayload sets the event payload from any value by marshaling to JSON
func (e *Event) SetPayload(v interface{}) error {
data, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("failed to marshal payload: %w", err)
}
e.Payload = data
return nil
}
// GetPayload unmarshals the payload into the provided value
func (e *Event) GetPayload(v interface{}) error {
if len(e.Payload) == 0 {
return fmt.Errorf("payload is empty")
}
if err := json.Unmarshal(e.Payload, v); err != nil {
return fmt.Errorf("failed to unmarshal payload: %w", err)
}
return nil
}
// Clone creates a deep copy of the event
func (e *Event) Clone() *Event {
clone := *e
// Deep copy metadata
if e.Metadata != nil {
clone.Metadata = make(map[string]interface{})
for k, v := range e.Metadata {
clone.Metadata[k] = v
}
}
// Deep copy timestamps
if e.ProcessedAt != nil {
t := *e.ProcessedAt
clone.ProcessedAt = &t
}
if e.CompletedAt != nil {
t := *e.CompletedAt
clone.CompletedAt = &t
}
return &clone
}
// Validate performs basic validation on the event
func (e *Event) Validate() error {
if e.ID == "" {
return fmt.Errorf("event ID is required")
}
if e.Source == "" {
return fmt.Errorf("event source is required")
}
if e.Type == "" {
return fmt.Errorf("event type is required")
}
if e.InstanceID == "" {
return fmt.Errorf("instance ID is required")
}
return nil
}

View File

@ -0,0 +1,314 @@
package eventbroker
import (
"encoding/json"
"errors"
"testing"
"time"
)
func TestNewEvent(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
if event.ID == "" {
t.Error("Expected event ID to be generated")
}
if event.Source != EventSourceDatabase {
t.Errorf("Expected source %s, got %s", EventSourceDatabase, event.Source)
}
if event.Type != "public.users.create" {
t.Errorf("Expected type 'public.users.create', got %s", event.Type)
}
if event.Status != EventStatusPending {
t.Errorf("Expected status %s, got %s", EventStatusPending, event.Status)
}
if event.CreatedAt.IsZero() {
t.Error("Expected CreatedAt to be set")
}
if event.Metadata == nil {
t.Error("Expected Metadata to be initialized")
}
}
func TestEventType(t *testing.T) {
tests := []struct {
schema string
entity string
operation string
expected string
}{
{"public", "users", "create", "public.users.create"},
{"admin", "roles", "update", "admin.roles.update"},
{"", "system", "start", ".system.start"}, // Empty schema results in leading dot
}
for _, tt := range tests {
result := EventType(tt.schema, tt.entity, tt.operation)
if result != tt.expected {
t.Errorf("EventType(%q, %q, %q) = %q, expected %q",
tt.schema, tt.entity, tt.operation, result, tt.expected)
}
}
}
func TestEventValidate(t *testing.T) {
tests := []struct {
name string
event *Event
wantError bool
}{
{
name: "valid event",
event: func() *Event {
e := NewEvent(EventSourceDatabase, "public.users.create")
e.InstanceID = "test-instance"
return e
}(),
wantError: false,
},
{
name: "missing ID",
event: &Event{
Source: EventSourceDatabase,
Type: "public.users.create",
Status: EventStatusPending,
},
wantError: true,
},
{
name: "missing source",
event: &Event{
ID: "test-id",
Type: "public.users.create",
Status: EventStatusPending,
},
wantError: true,
},
{
name: "missing type",
event: &Event{
ID: "test-id",
Source: EventSourceDatabase,
Status: EventStatusPending,
},
wantError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.event.Validate()
if (err != nil) != tt.wantError {
t.Errorf("Event.Validate() error = %v, wantError %v", err, tt.wantError)
}
})
}
}
func TestEventSetPayload(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
payload := map[string]interface{}{
"id": 1,
"name": "John Doe",
"email": "john@example.com",
}
err := event.SetPayload(payload)
if err != nil {
t.Fatalf("SetPayload failed: %v", err)
}
if event.Payload == nil {
t.Fatal("Expected payload to be set")
}
// Verify payload can be unmarshaled
var result map[string]interface{}
if err := json.Unmarshal(event.Payload, &result); err != nil {
t.Fatalf("Failed to unmarshal payload: %v", err)
}
if result["name"] != "John Doe" {
t.Errorf("Expected name 'John Doe', got %v", result["name"])
}
}
func TestEventGetPayload(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
payload := map[string]interface{}{
"id": float64(1), // JSON unmarshals numbers as float64
"name": "John Doe",
}
if err := event.SetPayload(payload); err != nil {
t.Fatalf("SetPayload failed: %v", err)
}
var result map[string]interface{}
if err := event.GetPayload(&result); err != nil {
t.Fatalf("GetPayload failed: %v", err)
}
if result["name"] != "John Doe" {
t.Errorf("Expected name 'John Doe', got %v", result["name"])
}
}
func TestEventMarkProcessing(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.MarkProcessing()
if event.Status != EventStatusProcessing {
t.Errorf("Expected status %s, got %s", EventStatusProcessing, event.Status)
}
if event.ProcessedAt == nil {
t.Error("Expected ProcessedAt to be set")
}
}
func TestEventMarkCompleted(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.MarkCompleted()
if event.Status != EventStatusCompleted {
t.Errorf("Expected status %s, got %s", EventStatusCompleted, event.Status)
}
if event.CompletedAt == nil {
t.Error("Expected CompletedAt to be set")
}
}
func TestEventMarkFailed(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
testErr := errors.New("test error")
event.MarkFailed(testErr)
if event.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, event.Status)
}
if event.Error != "test error" {
t.Errorf("Expected error %q, got %q", "test error", event.Error)
}
if event.CompletedAt == nil {
t.Error("Expected CompletedAt to be set")
}
}
func TestEventIncrementRetry(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
initialCount := event.RetryCount
event.IncrementRetry()
if event.RetryCount != initialCount+1 {
t.Errorf("Expected retry count %d, got %d", initialCount+1, event.RetryCount)
}
}
func TestEventJSONMarshaling(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
event.UserID = 123
event.SessionID = "session-123"
event.InstanceID = "instance-1"
event.Schema = "public"
event.Entity = "users"
event.Operation = "create"
event.SetPayload(map[string]interface{}{"name": "Test"})
// Marshal to JSON
data, err := json.Marshal(event)
if err != nil {
t.Fatalf("Failed to marshal event: %v", err)
}
// Unmarshal back
var decoded Event
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal event: %v", err)
}
// Verify fields
if decoded.ID != event.ID {
t.Errorf("Expected ID %s, got %s", event.ID, decoded.ID)
}
if decoded.Source != event.Source {
t.Errorf("Expected source %s, got %s", event.Source, decoded.Source)
}
if decoded.UserID != event.UserID {
t.Errorf("Expected UserID %d, got %d", event.UserID, decoded.UserID)
}
}
func TestEventStatusString(t *testing.T) {
statuses := []EventStatus{
EventStatusPending,
EventStatusProcessing,
EventStatusCompleted,
EventStatusFailed,
}
for _, status := range statuses {
if string(status) == "" {
t.Errorf("EventStatus %v has empty string representation", status)
}
}
}
func TestEventSourceString(t *testing.T) {
sources := []EventSource{
EventSourceDatabase,
EventSourceWebSocket,
EventSourceFrontend,
EventSourceSystem,
EventSourceInternal,
}
for _, source := range sources {
if string(source) == "" {
t.Errorf("EventSource %v has empty string representation", source)
}
}
}
func TestEventMetadata(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
// Test setting metadata
event.Metadata["key1"] = "value1"
event.Metadata["key2"] = 123
if event.Metadata["key1"] != "value1" {
t.Errorf("Expected metadata key1 to be 'value1', got %v", event.Metadata["key1"])
}
if event.Metadata["key2"] != 123 {
t.Errorf("Expected metadata key2 to be 123, got %v", event.Metadata["key2"])
}
}
func TestEventTimestamps(t *testing.T) {
event := NewEvent(EventSourceDatabase, "public.users.create")
createdAt := event.CreatedAt
// Wait a tiny bit to ensure timestamps differ
time.Sleep(time.Millisecond)
event.MarkProcessing()
if event.ProcessedAt == nil {
t.Fatal("ProcessedAt should be set")
}
if !event.ProcessedAt.After(createdAt) {
t.Error("ProcessedAt should be after CreatedAt")
}
time.Sleep(time.Millisecond)
event.MarkCompleted()
if event.CompletedAt == nil {
t.Fatal("CompletedAt should be set")
}
if !event.CompletedAt.After(*event.ProcessedAt) {
t.Error("CompletedAt should be after ProcessedAt")
}
}

View File

@ -0,0 +1,160 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"github.com/bitechdev/ResolveSpec/pkg/config"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/server"
)
var (
defaultBroker Broker
brokerMu sync.RWMutex
)
// Initialize initializes the global event broker from configuration
func Initialize(cfg config.EventBrokerConfig) error {
if !cfg.Enabled {
logger.Info("Event broker is disabled")
return nil
}
// Create provider
provider, err := NewProviderFromConfig(cfg)
if err != nil {
return fmt.Errorf("failed to create provider: %w", err)
}
// Parse mode
mode := ProcessingModeAsync
if cfg.Mode == "sync" {
mode = ProcessingModeSync
}
// Convert retry policy
retryPolicy := &RetryPolicy{
MaxRetries: cfg.RetryPolicy.MaxRetries,
InitialDelay: cfg.RetryPolicy.InitialDelay,
MaxDelay: cfg.RetryPolicy.MaxDelay,
BackoffFactor: cfg.RetryPolicy.BackoffFactor,
}
if retryPolicy.MaxRetries == 0 {
retryPolicy = DefaultRetryPolicy()
}
// Create broker options
opts := Options{
Provider: provider,
Mode: mode,
WorkerCount: cfg.WorkerCount,
BufferSize: cfg.BufferSize,
RetryPolicy: retryPolicy,
InstanceID: getInstanceID(cfg.InstanceID),
}
// Create broker
broker, err := NewBroker(opts)
if err != nil {
return fmt.Errorf("failed to create broker: %w", err)
}
// Start broker
if err := broker.Start(context.Background()); err != nil {
return fmt.Errorf("failed to start broker: %w", err)
}
// Set as default
SetDefaultBroker(broker)
// Register shutdown callback
RegisterShutdown(broker)
logger.Info("Event broker initialized successfully (provider: %s, mode: %s, instance: %s)",
cfg.Provider, cfg.Mode, opts.InstanceID)
return nil
}
// SetDefaultBroker sets the default global broker
func SetDefaultBroker(broker Broker) {
brokerMu.Lock()
defer brokerMu.Unlock()
defaultBroker = broker
}
// GetDefaultBroker returns the default global broker
func GetDefaultBroker() Broker {
brokerMu.RLock()
defer brokerMu.RUnlock()
return defaultBroker
}
// IsInitialized returns true if the default broker is initialized
func IsInitialized() bool {
return GetDefaultBroker() != nil
}
// Publish publishes an event using the default broker
func Publish(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.Publish(ctx, event)
}
// PublishSync publishes an event synchronously using the default broker
func PublishSync(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.PublishSync(ctx, event)
}
// PublishAsync publishes an event asynchronously using the default broker
func PublishAsync(ctx context.Context, event *Event) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.PublishAsync(ctx, event)
}
// Subscribe subscribes to events using the default broker
func Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
broker := GetDefaultBroker()
if broker == nil {
return "", fmt.Errorf("event broker not initialized")
}
return broker.Subscribe(pattern, handler)
}
// Unsubscribe unsubscribes from events using the default broker
func Unsubscribe(id SubscriptionID) error {
broker := GetDefaultBroker()
if broker == nil {
return fmt.Errorf("event broker not initialized")
}
return broker.Unsubscribe(id)
}
// Stats returns statistics from the default broker
func Stats(ctx context.Context) (*BrokerStats, error) {
broker := GetDefaultBroker()
if broker == nil {
return nil, fmt.Errorf("event broker not initialized")
}
return broker.Stats(ctx)
}
// RegisterShutdown registers the broker's shutdown with the server shutdown callbacks
func RegisterShutdown(broker Broker) {
server.RegisterShutdownCallback(func(ctx context.Context) error {
logger.Info("Shutting down event broker...")
return broker.Stop(ctx)
})
}

View File

@ -0,0 +1,266 @@
// nolint
package eventbroker
import (
"context"
"fmt"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example demonstrates basic usage of the event broker
func Example() {
// 1. Create a memory provider
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "example-instance",
MaxEvents: 1000,
CleanupInterval: 5 * time.Minute,
MaxAge: 1 * time.Hour,
})
// 2. Create a broker
broker, err := NewBroker(Options{
Provider: provider,
Mode: ProcessingModeAsync,
WorkerCount: 5,
BufferSize: 100,
RetryPolicy: DefaultRetryPolicy(),
InstanceID: "example-instance",
})
if err != nil {
logger.Error("Failed to create broker: %v", err)
return
}
// 3. Start the broker
if err := broker.Start(context.Background()); err != nil {
logger.Error("Failed to start broker: %v", err)
return
}
defer func() {
err := broker.Stop(context.Background())
if err != nil {
logger.Error("Failed to stop broker: %v", err)
}
}()
// 4. Subscribe to events
broker.Subscribe("public.users.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
logger.Info("User event: %s (operation: %s)", event.Type, event.Operation)
return nil
},
))
broker.Subscribe("*.*.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
logger.Info("Create event: %s.%s", event.Schema, event.Entity)
return nil
},
))
// 5. Publish events
ctx := context.Background()
// Database event
dbEvent := NewEvent(EventSourceDatabase, EventType("public", "users", "create"))
dbEvent.InstanceID = "example-instance"
dbEvent.UserID = 123
dbEvent.SessionID = "session-456"
dbEvent.Schema = "public"
dbEvent.Entity = "users"
dbEvent.Operation = "create"
dbEvent.SetPayload(map[string]interface{}{
"id": 123,
"name": "John Doe",
"email": "john@example.com",
})
if err := broker.PublishAsync(ctx, dbEvent); err != nil {
logger.Error("Failed to publish event: %v", err)
}
// WebSocket event
wsEvent := NewEvent(EventSourceWebSocket, "chat.message")
wsEvent.InstanceID = "example-instance"
wsEvent.UserID = 123
wsEvent.SessionID = "session-456"
wsEvent.SetPayload(map[string]interface{}{
"room": "general",
"message": "Hello, World!",
})
if err := broker.PublishAsync(ctx, wsEvent); err != nil {
logger.Error("Failed to publish event: %v", err)
}
// 6. Get statistics
time.Sleep(1 * time.Second) // Wait for processing
stats, _ := broker.Stats(ctx)
logger.Info("Broker stats: %d published, %d processed", stats.TotalPublished, stats.TotalProcessed)
}
// ExampleWithHooks demonstrates integration with the hook system
func ExampleWithHooks() {
// This would typically be called in your main.go or initialization code
// after setting up your restheadspec.Handler
// Pseudo-code (actual implementation would use real handler):
/*
broker := eventbroker.GetDefaultBroker()
hookRegistry := handler.Hooks()
// Register CRUD hooks
config := eventbroker.DefaultCRUDHookConfig()
config.EnableRead = false // Disable read events for performance
if err := eventbroker.RegisterCRUDHooks(broker, hookRegistry, config); err != nil {
logger.Error("Failed to register CRUD hooks: %v", err)
}
// Now all CRUD operations will automatically publish events
*/
}
// ExampleSubscriptionPatterns demonstrates different subscription patterns
func ExampleSubscriptionPatterns() {
broker := GetDefaultBroker()
if broker == nil {
return
}
// Pattern 1: Subscribe to all events from a specific entity
broker.Subscribe("public.users.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("User event: %s\n", event.Operation)
return nil
},
))
// Pattern 2: Subscribe to a specific operation across all entities
broker.Subscribe("*.*.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Create event: %s.%s\n", event.Schema, event.Entity)
return nil
},
))
// Pattern 3: Subscribe to all events in a schema
broker.Subscribe("public.*.*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Public schema event: %s.%s\n", event.Entity, event.Operation)
return nil
},
))
// Pattern 4: Subscribe to everything (use with caution)
broker.Subscribe("*", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
fmt.Printf("Any event: %s\n", event.Type)
return nil
},
))
}
// ExampleErrorHandling demonstrates error handling in event handlers
func ExampleErrorHandling() {
broker := GetDefaultBroker()
if broker == nil {
return
}
// Handler that may fail
broker.Subscribe("public.users.create", EventHandlerFunc(
func(ctx context.Context, event *Event) error {
// Simulate processing
var user struct {
ID int `json:"id"`
Email string `json:"email"`
}
if err := event.GetPayload(&user); err != nil {
return fmt.Errorf("invalid payload: %w", err)
}
// Validate
if user.Email == "" {
return fmt.Errorf("email is required")
}
// Process (e.g., send email)
logger.Info("Sending welcome email to %s", user.Email)
return nil
},
))
}
// ExampleConfiguration demonstrates initializing from configuration
func ExampleConfiguration() {
// This would typically be in your main.go
// Pseudo-code:
/*
// Load configuration
cfgMgr := config.NewManager()
if err := cfgMgr.Load(); err != nil {
logger.Fatal("Failed to load config: %v", err)
}
cfg, err := cfgMgr.GetConfig()
if err != nil {
logger.Fatal("Failed to get config: %v", err)
}
// Initialize event broker
if err := eventbroker.Initialize(cfg.EventBroker); err != nil {
logger.Fatal("Failed to initialize event broker: %v", err)
}
// Use the default broker
eventbroker.Subscribe("*.*.create", eventbroker.EventHandlerFunc(
func(ctx context.Context, event *eventbroker.Event) error {
logger.Info("Created: %s.%s", event.Schema, event.Entity)
return nil
},
))
*/
}
// ExampleYAMLConfiguration shows example YAML configuration
const ExampleYAMLConfiguration = `
event_broker:
enabled: true
provider: memory # memory, redis, nats, database
mode: async # sync, async
worker_count: 10
buffer_size: 1000
instance_id: "${HOSTNAME}"
# Memory provider is default, no additional config needed
# Redis provider (when provider: redis)
redis:
stream_name: "resolvespec:events"
consumer_group: "resolvespec-workers"
host: "localhost"
port: 6379
# NATS provider (when provider: nats)
nats:
url: "nats://localhost:4222"
stream_name: "RESOLVESPEC_EVENTS"
# Database provider (when provider: database)
database:
table_name: "events"
channel: "resolvespec_events"
# Retry policy
retry_policy:
max_retries: 3
initial_delay: 1s
max_delay: 30s
backoff_factor: 2.0
`

View File

@ -0,0 +1,56 @@
package eventbroker
import (
"fmt"
"os"
"time"
"github.com/bitechdev/ResolveSpec/pkg/config"
)
// NewProviderFromConfig creates a provider based on configuration
func NewProviderFromConfig(cfg config.EventBrokerConfig) (Provider, error) {
switch cfg.Provider {
case "memory":
cleanupInterval := 5 * time.Minute
if cfg.Database.PollInterval > 0 {
cleanupInterval = cfg.Database.PollInterval
}
return NewMemoryProvider(MemoryProviderOptions{
InstanceID: getInstanceID(cfg.InstanceID),
MaxEvents: 10000,
CleanupInterval: cleanupInterval,
}), nil
case "redis":
// Redis provider will be implemented in Phase 8
return nil, fmt.Errorf("redis provider not yet implemented")
case "nats":
// NATS provider will be implemented in Phase 9
return nil, fmt.Errorf("nats provider not yet implemented")
case "database":
// Database provider will be implemented in Phase 7
return nil, fmt.Errorf("database provider not yet implemented")
default:
return nil, fmt.Errorf("unknown provider: %s", cfg.Provider)
}
}
// getInstanceID returns the instance ID, defaulting to hostname if not specified
func getInstanceID(configID string) string {
if configID != "" {
return configID
}
// Try to get hostname
if hostname, err := os.Hostname(); err == nil {
return hostname
}
// Fallback to a default
return "resolvespec-instance"
}

View File

@ -0,0 +1,17 @@
package eventbroker
import "context"
// EventHandler processes an event
type EventHandler interface {
Handle(ctx context.Context, event *Event) error
}
// EventHandlerFunc is a function adapter for EventHandler
// This allows using regular functions as event handlers
type EventHandlerFunc func(ctx context.Context, event *Event) error
// Handle implements EventHandler
func (f EventHandlerFunc) Handle(ctx context.Context, event *Event) error {
return f(ctx, event)
}

137
pkg/eventbroker/hooks.go Normal file
View File

@ -0,0 +1,137 @@
package eventbroker
import (
"encoding/json"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// CRUDHookConfig configures which CRUD operations should trigger events
type CRUDHookConfig struct {
EnableCreate bool
EnableRead bool
EnableUpdate bool
EnableDelete bool
}
// DefaultCRUDHookConfig returns default configuration (all enabled)
func DefaultCRUDHookConfig() *CRUDHookConfig {
return &CRUDHookConfig{
EnableCreate: true,
EnableRead: false, // Typically disabled for performance
EnableUpdate: true,
EnableDelete: true,
}
}
// RegisterCRUDHooks registers event hooks for CRUD operations
// This integrates with the restheadspec.HookRegistry to automatically
// capture database events
func RegisterCRUDHooks(broker Broker, hookRegistry *restheadspec.HookRegistry, config *CRUDHookConfig) error {
if broker == nil {
return fmt.Errorf("broker cannot be nil")
}
if hookRegistry == nil {
return fmt.Errorf("hookRegistry cannot be nil")
}
if config == nil {
config = DefaultCRUDHookConfig()
}
// Create hook handler factory
createHookHandler := func(operation string) restheadspec.HookFunc {
return func(hookCtx *restheadspec.HookContext) error {
// Get user context from Go context
userCtx, ok := security.GetUserContext(hookCtx.Context)
if !ok || userCtx == nil {
logger.Debug("No user context found in hook")
userCtx = &security.UserContext{} // Empty user context
}
// Create event
event := NewEvent(EventSourceDatabase, EventType(hookCtx.Schema, hookCtx.Entity, operation))
event.InstanceID = broker.InstanceID()
event.UserID = userCtx.UserID
event.SessionID = userCtx.SessionID
event.Schema = hookCtx.Schema
event.Entity = hookCtx.Entity
event.Operation = operation
// Set payload based on operation
var payload interface{}
switch operation {
case "create":
payload = hookCtx.Result
case "read":
payload = hookCtx.Result
case "update":
payload = map[string]interface{}{
"id": hookCtx.ID,
"data": hookCtx.Data,
}
case "delete":
payload = map[string]interface{}{
"id": hookCtx.ID,
}
}
if payload != nil {
if err := event.SetPayload(payload); err != nil {
logger.Error("Failed to set event payload: %v", err)
payload = map[string]interface{}{"error": "failed to serialize payload"}
event.Payload, _ = json.Marshal(payload)
}
}
// Add metadata
if userCtx.UserName != "" {
event.Metadata["user_name"] = userCtx.UserName
}
if userCtx.Email != "" {
event.Metadata["user_email"] = userCtx.Email
}
if len(userCtx.Roles) > 0 {
event.Metadata["user_roles"] = userCtx.Roles
}
event.Metadata["table_name"] = hookCtx.TableName
// Publish asynchronously to not block CRUD operation
if err := broker.PublishAsync(hookCtx.Context, event); err != nil {
logger.Error("Failed to publish %s event for %s.%s: %v",
operation, hookCtx.Schema, hookCtx.Entity, err)
// Don't fail the CRUD operation if event publishing fails
return nil
}
logger.Debug("Published %s event for %s.%s (ID: %s)",
operation, hookCtx.Schema, hookCtx.Entity, event.ID)
return nil
}
}
// Register hooks based on configuration
if config.EnableCreate {
hookRegistry.Register(restheadspec.AfterCreate, createHookHandler("create"))
logger.Info("Registered event hook for CREATE operations")
}
if config.EnableRead {
hookRegistry.Register(restheadspec.AfterRead, createHookHandler("read"))
logger.Info("Registered event hook for READ operations")
}
if config.EnableUpdate {
hookRegistry.Register(restheadspec.AfterUpdate, createHookHandler("update"))
logger.Info("Registered event hook for UPDATE operations")
}
if config.EnableDelete {
hookRegistry.Register(restheadspec.AfterDelete, createHookHandler("delete"))
logger.Info("Registered event hook for DELETE operations")
}
return nil
}

View File

@ -0,0 +1,28 @@
package eventbroker
import (
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
)
// recordEventPublished records an event publication metric
func recordEventPublished(event *Event) {
if mp := metrics.GetProvider(); mp != nil {
mp.RecordEventPublished(string(event.Source), event.Type)
}
}
// recordEventProcessed records an event processing metric
func recordEventProcessed(event *Event, duration time.Duration) {
if mp := metrics.GetProvider(); mp != nil {
mp.RecordEventProcessed(string(event.Source), event.Type, string(event.Status), duration)
}
}
// updateQueueSize updates the event queue size metric
func updateQueueSize(size int64) {
if mp := metrics.GetProvider(); mp != nil {
mp.UpdateEventQueueSize(size)
}
}

View File

@ -0,0 +1,70 @@
package eventbroker
import (
"context"
"time"
)
// Provider defines the storage backend interface for events
// Implementations: MemoryProvider, RedisProvider, NATSProvider, DatabaseProvider
type Provider interface {
// Store stores an event
Store(ctx context.Context, event *Event) error
// Get retrieves an event by ID
Get(ctx context.Context, id string) (*Event, error)
// List lists events with optional filters
List(ctx context.Context, filter *EventFilter) ([]*Event, error)
// UpdateStatus updates the status of an event
UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error
// Delete deletes an event by ID
Delete(ctx context.Context, id string) error
// Stream returns a channel of events for real-time consumption
// Used for cross-instance pub/sub
// The channel is closed when the context is canceled or an error occurs
Stream(ctx context.Context, pattern string) (<-chan *Event, error)
// Publish publishes an event to all subscribers (for distributed providers)
// For in-memory provider, this is the same as Store
// For Redis/NATS/Database, this triggers cross-instance delivery
Publish(ctx context.Context, event *Event) error
// Close closes the provider and releases resources
Close() error
// Stats returns provider statistics
Stats(ctx context.Context) (*ProviderStats, error)
}
// EventFilter defines filter criteria for listing events
type EventFilter struct {
Source *EventSource
Status *EventStatus
UserID *int
Schema string
Entity string
Operation string
InstanceID string
StartTime *time.Time
EndTime *time.Time
Limit int
Offset int
}
// ProviderStats contains statistics about the provider
type ProviderStats struct {
ProviderType string `json:"provider_type"`
TotalEvents int64 `json:"total_events"`
PendingEvents int64 `json:"pending_events"`
ProcessingEvents int64 `json:"processing_events"`
CompletedEvents int64 `json:"completed_events"`
FailedEvents int64 `json:"failed_events"`
EventsPublished int64 `json:"events_published"`
EventsConsumed int64 `json:"events_consumed"`
ActiveSubscribers int `json:"active_subscribers"`
ProviderSpecific map[string]interface{} `json:"provider_specific,omitempty"`
}

View File

@ -0,0 +1,446 @@
package eventbroker
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// MemoryProvider implements Provider interface using in-memory storage
// Features:
// - Thread-safe event storage with RW mutex
// - LRU eviction when max events reached
// - In-process pub/sub (not cross-instance)
// - Automatic cleanup of old completed events
type MemoryProvider struct {
mu sync.RWMutex
events map[string]*Event
eventOrder []string // For LRU tracking
subscribers map[string][]chan *Event
instanceID string
maxEvents int
cleanupInterval time.Duration
maxAge time.Duration
// Statistics
stats MemoryProviderStats
// Lifecycle
stopCleanup chan struct{}
wg sync.WaitGroup
isRunning atomic.Bool
}
// MemoryProviderStats contains statistics for the memory provider
type MemoryProviderStats struct {
TotalEvents atomic.Int64
PendingEvents atomic.Int64
ProcessingEvents atomic.Int64
CompletedEvents atomic.Int64
FailedEvents atomic.Int64
EventsPublished atomic.Int64
EventsConsumed atomic.Int64
ActiveSubscribers atomic.Int32
Evictions atomic.Int64
}
// MemoryProviderOptions configures the memory provider
type MemoryProviderOptions struct {
InstanceID string
MaxEvents int
CleanupInterval time.Duration
MaxAge time.Duration
}
// NewMemoryProvider creates a new in-memory event provider
func NewMemoryProvider(opts MemoryProviderOptions) *MemoryProvider {
if opts.MaxEvents == 0 {
opts.MaxEvents = 10000 // Default
}
if opts.CleanupInterval == 0 {
opts.CleanupInterval = 5 * time.Minute // Default
}
if opts.MaxAge == 0 {
opts.MaxAge = 24 * time.Hour // Default: keep events for 24 hours
}
mp := &MemoryProvider{
events: make(map[string]*Event),
eventOrder: make([]string, 0),
subscribers: make(map[string][]chan *Event),
instanceID: opts.InstanceID,
maxEvents: opts.MaxEvents,
cleanupInterval: opts.CleanupInterval,
maxAge: opts.MaxAge,
stopCleanup: make(chan struct{}),
}
mp.isRunning.Store(true)
// Start cleanup goroutine
mp.wg.Add(1)
go mp.cleanupLoop()
logger.Info("Memory provider initialized (max_events: %d, cleanup: %v, max_age: %v)",
opts.MaxEvents, opts.CleanupInterval, opts.MaxAge)
return mp
}
// Store stores an event
func (mp *MemoryProvider) Store(ctx context.Context, event *Event) error {
mp.mu.Lock()
defer mp.mu.Unlock()
// Check if we need to evict oldest events
if len(mp.events) >= mp.maxEvents {
mp.evictOldestLocked()
}
// Store event
mp.events[event.ID] = event.Clone()
mp.eventOrder = append(mp.eventOrder, event.ID)
// Update statistics
mp.stats.TotalEvents.Add(1)
mp.updateStatusCountsLocked(event.Status, 1)
return nil
}
// Get retrieves an event by ID
func (mp *MemoryProvider) Get(ctx context.Context, id string) (*Event, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
event, exists := mp.events[id]
if !exists {
return nil, fmt.Errorf("event not found: %s", id)
}
return event.Clone(), nil
}
// List lists events with optional filters
func (mp *MemoryProvider) List(ctx context.Context, filter *EventFilter) ([]*Event, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
var results []*Event
for _, event := range mp.events {
if mp.matchesFilter(event, filter) {
results = append(results, event.Clone())
}
}
// Apply limit and offset
if filter != nil {
if filter.Offset > 0 && filter.Offset < len(results) {
results = results[filter.Offset:]
}
if filter.Limit > 0 && filter.Limit < len(results) {
results = results[:filter.Limit]
}
}
return results, nil
}
// UpdateStatus updates the status of an event
func (mp *MemoryProvider) UpdateStatus(ctx context.Context, id string, status EventStatus, errorMsg string) error {
mp.mu.Lock()
defer mp.mu.Unlock()
event, exists := mp.events[id]
if !exists {
return fmt.Errorf("event not found: %s", id)
}
// Update status counts
mp.updateStatusCountsLocked(event.Status, -1)
mp.updateStatusCountsLocked(status, 1)
// Update event
event.Status = status
if errorMsg != "" {
event.Error = errorMsg
}
return nil
}
// Delete deletes an event by ID
func (mp *MemoryProvider) Delete(ctx context.Context, id string) error {
mp.mu.Lock()
defer mp.mu.Unlock()
event, exists := mp.events[id]
if !exists {
return fmt.Errorf("event not found: %s", id)
}
// Update counts
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
// Delete event
delete(mp.events, id)
// Remove from order tracking
for i, eid := range mp.eventOrder {
if eid == id {
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
break
}
}
return nil
}
// Stream returns a channel of events for real-time consumption
// Note: This is in-process only, not cross-instance
func (mp *MemoryProvider) Stream(ctx context.Context, pattern string) (<-chan *Event, error) {
mp.mu.Lock()
defer mp.mu.Unlock()
// Create buffered channel for events
ch := make(chan *Event, 100)
// Store subscriber
mp.subscribers[pattern] = append(mp.subscribers[pattern], ch)
mp.stats.ActiveSubscribers.Add(1)
// Goroutine to clean up on context cancellation
mp.wg.Add(1)
go func() {
defer mp.wg.Done()
<-ctx.Done()
mp.mu.Lock()
defer mp.mu.Unlock()
// Remove subscriber
subs := mp.subscribers[pattern]
for i, subCh := range subs {
if subCh == ch {
mp.subscribers[pattern] = append(subs[:i], subs[i+1:]...)
break
}
}
mp.stats.ActiveSubscribers.Add(-1)
close(ch)
}()
logger.Debug("Stream created for pattern: %s", pattern)
return ch, nil
}
// Publish publishes an event to all subscribers
func (mp *MemoryProvider) Publish(ctx context.Context, event *Event) error {
// Store the event first
if err := mp.Store(ctx, event); err != nil {
return err
}
mp.stats.EventsPublished.Add(1)
// Notify subscribers
mp.mu.RLock()
defer mp.mu.RUnlock()
for pattern, channels := range mp.subscribers {
if matchPattern(pattern, event.Type) {
for _, ch := range channels {
select {
case ch <- event.Clone():
mp.stats.EventsConsumed.Add(1)
default:
// Channel full, skip
logger.Warn("Subscriber channel full for pattern: %s", pattern)
}
}
}
}
return nil
}
// Close closes the provider and releases resources
func (mp *MemoryProvider) Close() error {
if !mp.isRunning.Load() {
return nil
}
mp.isRunning.Store(false)
// Stop cleanup loop
close(mp.stopCleanup)
// Wait for goroutines
mp.wg.Wait()
// Close all subscriber channels
mp.mu.Lock()
for _, channels := range mp.subscribers {
for _, ch := range channels {
close(ch)
}
}
mp.subscribers = make(map[string][]chan *Event)
mp.mu.Unlock()
logger.Info("Memory provider closed")
return nil
}
// Stats returns provider statistics
func (mp *MemoryProvider) Stats(ctx context.Context) (*ProviderStats, error) {
return &ProviderStats{
ProviderType: "memory",
TotalEvents: mp.stats.TotalEvents.Load(),
PendingEvents: mp.stats.PendingEvents.Load(),
ProcessingEvents: mp.stats.ProcessingEvents.Load(),
CompletedEvents: mp.stats.CompletedEvents.Load(),
FailedEvents: mp.stats.FailedEvents.Load(),
EventsPublished: mp.stats.EventsPublished.Load(),
EventsConsumed: mp.stats.EventsConsumed.Load(),
ActiveSubscribers: int(mp.stats.ActiveSubscribers.Load()),
ProviderSpecific: map[string]interface{}{
"max_events": mp.maxEvents,
"cleanup_interval": mp.cleanupInterval.String(),
"max_age": mp.maxAge.String(),
"evictions": mp.stats.Evictions.Load(),
},
}, nil
}
// cleanupLoop periodically cleans up old completed events
func (mp *MemoryProvider) cleanupLoop() {
defer mp.wg.Done()
ticker := time.NewTicker(mp.cleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
mp.cleanup()
case <-mp.stopCleanup:
return
}
}
}
// cleanup removes old completed/failed events
func (mp *MemoryProvider) cleanup() {
mp.mu.Lock()
defer mp.mu.Unlock()
cutoff := time.Now().Add(-mp.maxAge)
removed := 0
for id, event := range mp.events {
// Only clean up completed or failed events that are old
if (event.Status == EventStatusCompleted || event.Status == EventStatusFailed) &&
event.CreatedAt.Before(cutoff) {
delete(mp.events, id)
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
// Remove from order tracking
for i, eid := range mp.eventOrder {
if eid == id {
mp.eventOrder = append(mp.eventOrder[:i], mp.eventOrder[i+1:]...)
break
}
}
removed++
}
}
if removed > 0 {
logger.Debug("Cleanup removed %d old events", removed)
}
}
// evictOldestLocked evicts the oldest event (LRU)
// Caller must hold write lock
func (mp *MemoryProvider) evictOldestLocked() {
if len(mp.eventOrder) == 0 {
return
}
// Get oldest event ID
oldestID := mp.eventOrder[0]
mp.eventOrder = mp.eventOrder[1:]
// Remove event
if event, exists := mp.events[oldestID]; exists {
delete(mp.events, oldestID)
mp.stats.TotalEvents.Add(-1)
mp.updateStatusCountsLocked(event.Status, -1)
mp.stats.Evictions.Add(1)
logger.Debug("Evicted oldest event: %s", oldestID)
}
}
// matchesFilter checks if an event matches the filter criteria
func (mp *MemoryProvider) matchesFilter(event *Event, filter *EventFilter) bool {
if filter == nil {
return true
}
if filter.Source != nil && event.Source != *filter.Source {
return false
}
if filter.Status != nil && event.Status != *filter.Status {
return false
}
if filter.UserID != nil && event.UserID != *filter.UserID {
return false
}
if filter.Schema != "" && event.Schema != filter.Schema {
return false
}
if filter.Entity != "" && event.Entity != filter.Entity {
return false
}
if filter.Operation != "" && event.Operation != filter.Operation {
return false
}
if filter.InstanceID != "" && event.InstanceID != filter.InstanceID {
return false
}
if filter.StartTime != nil && event.CreatedAt.Before(*filter.StartTime) {
return false
}
if filter.EndTime != nil && event.CreatedAt.After(*filter.EndTime) {
return false
}
return true
}
// updateStatusCountsLocked updates status statistics
// Caller must hold write lock
func (mp *MemoryProvider) updateStatusCountsLocked(status EventStatus, delta int64) {
switch status {
case EventStatusPending:
mp.stats.PendingEvents.Add(delta)
case EventStatusProcessing:
mp.stats.ProcessingEvents.Add(delta)
case EventStatusCompleted:
mp.stats.CompletedEvents.Add(delta)
case EventStatusFailed:
mp.stats.FailedEvents.Add(delta)
}
}

View File

@ -0,0 +1,419 @@
package eventbroker
import (
"context"
"testing"
"time"
)
func TestNewMemoryProvider(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 100,
CleanupInterval: 1 * time.Minute,
})
if provider == nil {
t.Fatal("Expected non-nil provider")
}
stats, err := provider.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.ProviderType != "memory" {
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
}
}
func TestMemoryProviderPublishAndGet(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
event.UserID = 123
// Publish event
if err := provider.Publish(context.Background(), event); err != nil {
t.Fatalf("Publish failed: %v", err)
}
// Get event
retrieved, err := provider.Get(context.Background(), event.ID)
if err != nil {
t.Fatalf("Get failed: %v", err)
}
if retrieved.ID != event.ID {
t.Errorf("Expected event ID %s, got %s", event.ID, retrieved.ID)
}
if retrieved.UserID != 123 {
t.Errorf("Expected user ID 123, got %d", retrieved.UserID)
}
}
func TestMemoryProviderGetNonExistent(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
_, err := provider.Get(context.Background(), "non-existent-id")
if err == nil {
t.Error("Expected error when getting non-existent event")
}
}
func TestMemoryProviderUpdateStatus(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
// Update status to processing
err := provider.UpdateStatus(context.Background(), event.ID, EventStatusProcessing, "")
if err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
retrieved, _ := provider.Get(context.Background(), event.ID)
if retrieved.Status != EventStatusProcessing {
t.Errorf("Expected status %s, got %s", EventStatusProcessing, retrieved.Status)
}
// Update status to failed with error
err = provider.UpdateStatus(context.Background(), event.ID, EventStatusFailed, "test error")
if err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
retrieved, _ = provider.Get(context.Background(), event.ID)
if retrieved.Status != EventStatusFailed {
t.Errorf("Expected status %s, got %s", EventStatusFailed, retrieved.Status)
}
if retrieved.Error != "test error" {
t.Errorf("Expected error 'test error', got %s", retrieved.Error)
}
}
func TestMemoryProviderList(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish multiple events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
}
// List all events
events, err := provider.List(context.Background(), &EventFilter{})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 5 {
t.Errorf("Expected 5 events, got %d", len(events))
}
}
func TestMemoryProviderListWithFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events with different types
event1 := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event1)
event2 := NewEvent(EventSourceDatabase, "public.roles.create")
provider.Publish(context.Background(), event2)
event3 := NewEvent(EventSourceWebSocket, "chat.message")
provider.Publish(context.Background(), event3)
// Filter by source
source := EventSourceDatabase
events, err := provider.List(context.Background(), &EventFilter{
Source: &source,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 2 {
t.Errorf("Expected 2 events with database source, got %d", len(events))
}
// Filter by status
status := EventStatusPending
events, err = provider.List(context.Background(), &EventFilter{
Status: &status,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 3 {
t.Errorf("Expected 3 events with pending status, got %d", len(events))
}
}
func TestMemoryProviderListWithLimit(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish multiple events
for i := 0; i < 10; i++ {
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}
// List with limit
events, err := provider.List(context.Background(), &EventFilter{
Limit: 5,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 5 {
t.Errorf("Expected 5 events (limited), got %d", len(events))
}
}
func TestMemoryProviderDelete(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
event := NewEvent(EventSourceDatabase, "public.users.create")
provider.Publish(context.Background(), event)
// Delete event
err := provider.Delete(context.Background(), event.ID)
if err != nil {
t.Fatalf("Delete failed: %v", err)
}
// Verify deleted
_, err = provider.Get(context.Background(), event.ID)
if err == nil {
t.Error("Expected error when getting deleted event")
}
}
func TestMemoryProviderLRUEviction(t *testing.T) {
// Create provider with small max events
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 3,
})
// Publish 5 events
events := make([]*Event, 5)
for i := 0; i < 5; i++ {
events[i] = NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), events[i])
}
// First 2 events should be evicted
_, err := provider.Get(context.Background(), events[0].ID)
if err == nil {
t.Error("Expected first event to be evicted")
}
_, err = provider.Get(context.Background(), events[1].ID)
if err == nil {
t.Error("Expected second event to be evicted")
}
// Last 3 events should still exist
for i := 2; i < 5; i++ {
_, err := provider.Get(context.Background(), events[i].ID)
if err != nil {
t.Errorf("Expected event %d to still exist", i)
}
}
}
func TestMemoryProviderCleanup(t *testing.T) {
// Create provider with short cleanup interval
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
CleanupInterval: 100 * time.Millisecond,
MaxAge: 200 * time.Millisecond,
})
// Publish and complete an event
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
provider.UpdateStatus(context.Background(), event.ID, EventStatusCompleted, "")
// Wait for cleanup to run
time.Sleep(400 * time.Millisecond)
// Event should be cleaned up
_, err := provider.Get(context.Background(), event.ID)
if err == nil {
t.Error("Expected event to be cleaned up")
}
provider.Close()
}
func TestMemoryProviderStats(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
MaxEvents: 100,
})
// Publish events
for i := 0; i < 5; i++ {
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}
stats, err := provider.Stats(context.Background())
if err != nil {
t.Fatalf("Stats failed: %v", err)
}
if stats.ProviderType != "memory" {
t.Errorf("Expected provider type 'memory', got %s", stats.ProviderType)
}
if stats.TotalEvents != 5 {
t.Errorf("Expected 5 total events, got %d", stats.TotalEvents)
}
}
func TestMemoryProviderClose(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
CleanupInterval: 100 * time.Millisecond,
})
// Publish event
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
// Close provider
err := provider.Close()
if err != nil {
t.Fatalf("Close failed: %v", err)
}
// Cleanup goroutine should be stopped
time.Sleep(200 * time.Millisecond)
}
func TestMemoryProviderConcurrency(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Concurrent publish
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
defer func() { done <- true }()
event := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event)
}()
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Verify all events were stored
events, _ := provider.List(context.Background(), &EventFilter{})
if len(events) != 10 {
t.Errorf("Expected 10 events, got %d", len(events))
}
}
func TestMemoryProviderStream(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Stream is implemented for memory provider (in-process pub/sub)
ch, err := provider.Stream(context.Background(), "test.*")
if err != nil {
t.Fatalf("Stream failed: %v", err)
}
if ch == nil {
t.Error("Expected non-nil channel")
}
}
func TestMemoryProviderTimeRangeFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events at different times
event1 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event1)
time.Sleep(10 * time.Millisecond)
event2 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event2)
time.Sleep(10 * time.Millisecond)
event3 := NewEvent(EventSourceDatabase, "test.event")
provider.Publish(context.Background(), event3)
// Filter by time range
startTime := event2.CreatedAt.Add(-1 * time.Millisecond)
events, err := provider.List(context.Background(), &EventFilter{
StartTime: &startTime,
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
// Should get events 2 and 3
if len(events) != 2 {
t.Errorf("Expected 2 events after start time, got %d", len(events))
}
}
func TestMemoryProviderInstanceIDFilter(t *testing.T) {
provider := NewMemoryProvider(MemoryProviderOptions{
InstanceID: "test-instance",
})
// Publish events with different instance IDs
event1 := NewEvent(EventSourceDatabase, "test.event")
event1.InstanceID = "instance-1"
provider.Publish(context.Background(), event1)
event2 := NewEvent(EventSourceDatabase, "test.event")
event2.InstanceID = "instance-2"
provider.Publish(context.Background(), event2)
// Filter by instance ID
events, err := provider.List(context.Background(), &EventFilter{
InstanceID: "instance-1",
})
if err != nil {
t.Fatalf("List failed: %v", err)
}
if len(events) != 1 {
t.Errorf("Expected 1 event with instance-1, got %d", len(events))
}
if events[0].InstanceID != "instance-1" {
t.Errorf("Expected instance ID 'instance-1', got %s", events[0].InstanceID)
}
}

View File

@ -0,0 +1,140 @@
package eventbroker
import (
"fmt"
"strings"
"sync"
"sync/atomic"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// SubscriptionID uniquely identifies a subscription
type SubscriptionID string
// subscription represents a single subscription with its handler and pattern
type subscription struct {
id SubscriptionID
pattern string
handler EventHandler
}
// subscriptionManager manages event subscriptions and pattern matching
type subscriptionManager struct {
mu sync.RWMutex
subscriptions map[SubscriptionID]*subscription
nextID atomic.Uint64
}
// newSubscriptionManager creates a new subscription manager
func newSubscriptionManager() *subscriptionManager {
return &subscriptionManager{
subscriptions: make(map[SubscriptionID]*subscription),
}
}
// Subscribe adds a new subscription
func (sm *subscriptionManager) Subscribe(pattern string, handler EventHandler) (SubscriptionID, error) {
if pattern == "" {
return "", fmt.Errorf("pattern cannot be empty")
}
if handler == nil {
return "", fmt.Errorf("handler cannot be nil")
}
id := SubscriptionID(fmt.Sprintf("sub-%d", sm.nextID.Add(1)))
sm.mu.Lock()
sm.subscriptions[id] = &subscription{
id: id,
pattern: pattern,
handler: handler,
}
sm.mu.Unlock()
logger.Info("Subscribed to pattern '%s' with ID: %s", pattern, id)
return id, nil
}
// Unsubscribe removes a subscription
func (sm *subscriptionManager) Unsubscribe(id SubscriptionID) error {
sm.mu.Lock()
defer sm.mu.Unlock()
if _, exists := sm.subscriptions[id]; !exists {
return fmt.Errorf("subscription not found: %s", id)
}
delete(sm.subscriptions, id)
logger.Info("Unsubscribed: %s", id)
return nil
}
// GetMatching returns all handlers that match the event type
func (sm *subscriptionManager) GetMatching(eventType string) []EventHandler {
sm.mu.RLock()
defer sm.mu.RUnlock()
var handlers []EventHandler
for _, sub := range sm.subscriptions {
if matchPattern(sub.pattern, eventType) {
handlers = append(handlers, sub.handler)
}
}
return handlers
}
// Count returns the number of active subscriptions
func (sm *subscriptionManager) Count() int {
sm.mu.RLock()
defer sm.mu.RUnlock()
return len(sm.subscriptions)
}
// Clear removes all subscriptions
func (sm *subscriptionManager) Clear() {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.subscriptions = make(map[SubscriptionID]*subscription)
logger.Info("Cleared all subscriptions")
}
// matchPattern implements glob-style pattern matching for event types
// Patterns:
// - "*" matches any single segment
// - "a.b.c" matches exactly "a.b.c"
// - "a.*.c" matches "a.anything.c"
// - "a.b.*" matches any operation on a.b
// - "*" matches everything
//
// Event type format: schema.entity.operation (e.g., "public.users.create")
func matchPattern(pattern, eventType string) bool {
// Wildcard matches everything
if pattern == "*" {
return true
}
// Exact match
if pattern == eventType {
return true
}
// Split pattern and event type by dots
patternParts := strings.Split(pattern, ".")
eventParts := strings.Split(eventType, ".")
// Different number of parts can only match if pattern has wildcards
if len(patternParts) != len(eventParts) {
return false
}
// Match each part
for i := range patternParts {
if patternParts[i] != "*" && patternParts[i] != eventParts[i] {
return false
}
}
return true
}

View File

@ -0,0 +1,270 @@
package eventbroker
import (
"context"
"testing"
)
func TestMatchPattern(t *testing.T) {
tests := []struct {
pattern string
eventType string
expected bool
}{
// Exact matches
{"public.users.create", "public.users.create", true},
{"public.users.create", "public.users.update", false},
// Wildcard matches
{"*", "public.users.create", true},
{"*", "anything", true},
{"public.*", "public.users", true},
{"public.*", "public.users.create", false}, // Different number of parts
{"public.*", "admin.users", false},
{"*.users.create", "public.users.create", true},
{"*.users.create", "admin.users.create", true},
{"*.users.create", "public.roles.create", false},
{"public.*.create", "public.users.create", true},
{"public.*.create", "public.roles.create", true},
{"public.*.create", "public.users.update", false},
// Multiple wildcards
{"*.*", "public.users", true},
{"*.*", "public.users.create", false}, // Different number of parts
{"*.*.create", "public.users.create", true},
{"*.*.create", "admin.roles.create", true},
{"*.*.create", "public.users.update", false},
// Edge cases
{"", "", true},
{"", "something", false},
{"something", "", false},
}
for _, tt := range tests {
t.Run(tt.pattern+"_vs_"+tt.eventType, func(t *testing.T) {
result := matchPattern(tt.pattern, tt.eventType)
if result != tt.expected {
t.Errorf("matchPattern(%q, %q) = %v, expected %v",
tt.pattern, tt.eventType, result, tt.expected)
}
})
}
}
func TestSubscriptionManager(t *testing.T) {
manager := newSubscriptionManager()
// Create test handler
called := false
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
return nil
})
// Test Subscribe
id, err := manager.Subscribe("public.users.*", handler)
if err != nil {
t.Fatalf("Subscribe failed: %v", err)
}
if id == "" {
t.Fatal("Expected non-empty subscription ID")
}
// Test GetMatching
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 1 {
t.Fatalf("Expected 1 handler, got %d", len(handlers))
}
// Test handler execution
event := NewEvent(EventSourceDatabase, "public.users.create")
if err := handlers[0].Handle(context.Background(), event); err != nil {
t.Fatalf("Handler execution failed: %v", err)
}
if !called {
t.Error("Expected handler to be called")
}
// Test Count
if manager.Count() != 1 {
t.Errorf("Expected count 1, got %d", manager.Count())
}
// Test Unsubscribe
if err := manager.Unsubscribe(id); err != nil {
t.Fatalf("Unsubscribe failed: %v", err)
}
// Verify unsubscribed
handlers = manager.GetMatching("public.users.create")
if len(handlers) != 0 {
t.Errorf("Expected 0 handlers after unsubscribe, got %d", len(handlers))
}
if manager.Count() != 0 {
t.Errorf("Expected count 0 after unsubscribe, got %d", manager.Count())
}
}
func TestSubscriptionManagerMultipleHandlers(t *testing.T) {
manager := newSubscriptionManager()
called1 := false
handler1 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called1 = true
return nil
})
called2 := false
handler2 := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called2 = true
return nil
})
// Subscribe multiple handlers
id1, _ := manager.Subscribe("public.users.*", handler1)
id2, _ := manager.Subscribe("*.users.*", handler2)
// Both should match
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 2 {
t.Fatalf("Expected 2 handlers, got %d", len(handlers))
}
// Execute all handlers
event := NewEvent(EventSourceDatabase, "public.users.create")
for _, h := range handlers {
h.Handle(context.Background(), event)
}
if !called1 || !called2 {
t.Error("Expected both handlers to be called")
}
// Unsubscribe one
manager.Unsubscribe(id1)
handlers = manager.GetMatching("public.users.create")
if len(handlers) != 1 {
t.Errorf("Expected 1 handler after unsubscribe, got %d", len(handlers))
}
// Unsubscribe remaining
manager.Unsubscribe(id2)
if manager.Count() != 0 {
t.Errorf("Expected count 0 after all unsubscribe, got %d", manager.Count())
}
}
func TestSubscriptionManagerConcurrency(t *testing.T) {
manager := newSubscriptionManager()
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
})
// Subscribe and unsubscribe concurrently
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func() {
defer func() { done <- true }()
id, _ := manager.Subscribe("test.*", handler)
manager.GetMatching("test.event")
manager.Unsubscribe(id)
}()
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
// Should have no subscriptions left
if manager.Count() != 0 {
t.Errorf("Expected count 0 after concurrent operations, got %d", manager.Count())
}
}
func TestSubscriptionManagerUnsubscribeNonExistent(t *testing.T) {
manager := newSubscriptionManager()
// Try to unsubscribe a non-existent ID
err := manager.Unsubscribe("non-existent-id")
if err == nil {
t.Error("Expected error when unsubscribing non-existent ID")
}
}
func TestSubscriptionIDGeneration(t *testing.T) {
manager := newSubscriptionManager()
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
return nil
})
// Subscribe multiple times and ensure unique IDs
ids := make(map[SubscriptionID]bool)
for i := 0; i < 100; i++ {
id, _ := manager.Subscribe("test.*", handler)
if ids[id] {
t.Fatalf("Duplicate subscription ID: %s", id)
}
ids[id] = true
}
}
func TestEventHandlerFunc(t *testing.T) {
called := false
var receivedEvent *Event
handler := EventHandlerFunc(func(ctx context.Context, event *Event) error {
called = true
receivedEvent = event
return nil
})
event := NewEvent(EventSourceDatabase, "test.event")
err := handler.Handle(context.Background(), event)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if !called {
t.Error("Expected handler to be called")
}
if receivedEvent != event {
t.Error("Expected to receive the same event")
}
}
func TestSubscriptionManagerPatternPriority(t *testing.T) {
manager := newSubscriptionManager()
// More specific patterns should still match
specificCalled := false
genericCalled := false
manager.Subscribe("public.users.create", EventHandlerFunc(func(ctx context.Context, event *Event) error {
specificCalled = true
return nil
}))
manager.Subscribe("*", EventHandlerFunc(func(ctx context.Context, event *Event) error {
genericCalled = true
return nil
}))
handlers := manager.GetMatching("public.users.create")
if len(handlers) != 2 {
t.Fatalf("Expected 2 matching handlers, got %d", len(handlers))
}
// Execute all handlers
event := NewEvent(EventSourceDatabase, "public.users.create")
for _, h := range handlers {
h.Handle(context.Background(), event)
}
if !specificCalled || !genericCalled {
t.Error("Expected both specific and generic handlers to be called")
}
}

View File

@ -0,0 +1,141 @@
package eventbroker
import (
"context"
"sync"
"sync/atomic"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// workerPool manages a pool of workers for async event processing
type workerPool struct {
workerCount int
bufferSize int
eventQueue chan *Event
processor func(context.Context, *Event) error
activeWorkers atomic.Int32
isRunning atomic.Bool
stopCh chan struct{}
wg sync.WaitGroup
}
// newWorkerPool creates a new worker pool
func newWorkerPool(workerCount, bufferSize int, processor func(context.Context, *Event) error) *workerPool {
return &workerPool{
workerCount: workerCount,
bufferSize: bufferSize,
eventQueue: make(chan *Event, bufferSize),
processor: processor,
stopCh: make(chan struct{}),
}
}
// Start starts the worker pool
func (wp *workerPool) Start() {
if wp.isRunning.Load() {
return
}
wp.isRunning.Store(true)
// Start workers
for i := 0; i < wp.workerCount; i++ {
wp.wg.Add(1)
go wp.worker(i)
}
logger.Info("Worker pool started with %d workers", wp.workerCount)
}
// Stop stops the worker pool gracefully
func (wp *workerPool) Stop(ctx context.Context) error {
if !wp.isRunning.Load() {
return nil
}
wp.isRunning.Store(false)
// Close event queue to signal workers
close(wp.eventQueue)
// Wait for workers to finish with context timeout
done := make(chan struct{})
go func() {
wp.wg.Wait()
close(done)
}()
select {
case <-done:
logger.Info("Worker pool stopped gracefully")
return nil
case <-ctx.Done():
logger.Warn("Worker pool stop timed out, some events may be lost")
return ctx.Err()
}
}
// Submit submits an event to the queue
func (wp *workerPool) Submit(ctx context.Context, event *Event) error {
if !wp.isRunning.Load() {
return ErrWorkerPoolStopped
}
select {
case wp.eventQueue <- event:
return nil
case <-ctx.Done():
return ctx.Err()
default:
return ErrQueueFull
}
}
// worker is a worker goroutine that processes events from the queue
func (wp *workerPool) worker(id int) {
defer wp.wg.Done()
logger.Debug("Worker %d started", id)
for event := range wp.eventQueue {
wp.activeWorkers.Add(1)
// Process event with background context (detached from original request)
ctx := context.Background()
if err := wp.processor(ctx, event); err != nil {
logger.Error("Worker %d failed to process event %s: %v", id, event.ID, err)
}
wp.activeWorkers.Add(-1)
}
logger.Debug("Worker %d stopped", id)
}
// QueueSize returns the current queue size
func (wp *workerPool) QueueSize() int {
return len(wp.eventQueue)
}
// ActiveWorkers returns the number of currently active workers
func (wp *workerPool) ActiveWorkers() int {
return int(wp.activeWorkers.Load())
}
// Error definitions
var (
ErrWorkerPoolStopped = &BrokerError{Code: "worker_pool_stopped", Message: "worker pool is stopped"}
ErrQueueFull = &BrokerError{Code: "queue_full", Message: "event queue is full"}
)
// BrokerError represents an error from the event broker
type BrokerError struct {
Code string
Message string
}
func (e *BrokerError) Error() string {
return e.Message
}

1056
pkg/funcspec/function_api.go Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,910 @@
package funcspec
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// MockDatabase implements common.Database interface for testing
type MockDatabase struct {
QueryFunc func(ctx context.Context, dest interface{}, query string, args ...interface{}) error
ExecFunc func(ctx context.Context, query string, args ...interface{}) (common.Result, error)
RunInTransactionFunc func(ctx context.Context, fn func(common.Database) error) error
}
func (m *MockDatabase) NewSelect() common.SelectQuery {
return nil
}
func (m *MockDatabase) NewInsert() common.InsertQuery {
return nil
}
func (m *MockDatabase) NewUpdate() common.UpdateQuery {
return nil
}
func (m *MockDatabase) NewDelete() common.DeleteQuery {
return nil
}
func (m *MockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
if m.ExecFunc != nil {
return m.ExecFunc(ctx, query, args...)
}
return &MockResult{rows: 0}, nil
}
func (m *MockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if m.QueryFunc != nil {
return m.QueryFunc(ctx, dest, query, args...)
}
return nil
}
func (m *MockDatabase) BeginTx(ctx context.Context) (common.Database, error) {
return m, nil
}
func (m *MockDatabase) CommitTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RollbackTx(ctx context.Context) error {
return nil
}
func (m *MockDatabase) RunInTransaction(ctx context.Context, fn func(common.Database) error) error {
if m.RunInTransactionFunc != nil {
return m.RunInTransactionFunc(ctx, fn)
}
return fn(m)
}
func (m *MockDatabase) GetUnderlyingDB() interface{} {
return m
}
// MockResult implements common.Result interface for testing
type MockResult struct {
rows int64
id int64
}
func (m *MockResult) RowsAffected() int64 {
return m.rows
}
func (m *MockResult) LastInsertId() (int64, error) {
return m.id, nil
}
// Helper function to create a test request with user context
func createTestRequest(method, path string, queryParams map[string]string, headers map[string]string, body []byte) *http.Request {
u, _ := url.Parse(path)
if queryParams != nil {
q := u.Query()
for k, v := range queryParams {
q.Set(k, v)
}
u.RawQuery = q.Encode()
}
var bodyReader *bytes.Reader
if body != nil {
bodyReader = bytes.NewReader(body)
} else {
bodyReader = bytes.NewReader([]byte{})
}
req := httptest.NewRequest(method, u.String(), bodyReader)
if headers != nil {
for k, v := range headers {
req.Header.Set(k, v)
}
}
// Add user context
userCtx := &security.UserContext{
UserID: 1,
UserName: "testuser",
SessionID: "test-session-123",
}
ctx := context.WithValue(req.Context(), security.UserContextKey, userCtx)
req = req.WithContext(ctx)
return req
}
// TestNewHandler tests handler creation
func TestNewHandler(t *testing.T) {
db := &MockDatabase{}
handler := NewHandler(db)
if handler == nil {
t.Fatal("Expected handler to be created, got nil")
}
if handler.db != db {
t.Error("Expected handler to have the provided database")
}
if handler.hooks == nil {
t.Error("Expected handler to have a hook registry")
}
}
// TestHandlerHooks tests the Hooks method
func TestHandlerHooks(t *testing.T) {
handler := NewHandler(&MockDatabase{})
hooks := handler.Hooks()
if hooks == nil {
t.Fatal("Expected hooks registry to be non-nil")
}
// Should return the same instance
hooks2 := handler.Hooks()
if hooks != hooks2 {
t.Error("Expected Hooks() to return the same registry instance")
}
}
// TestExtractInputVariables tests the extractInputVariables function
func TestExtractInputVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
expectedVars []string
}{
{
name: "No variables",
sqlQuery: "SELECT * FROM users",
expectedVars: []string{},
},
{
name: "Single variable",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
expectedVars: []string{"[user_id]"},
},
{
name: "Multiple variables",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = [user_name]",
expectedVars: []string{"[user_id]", "[user_name]"},
},
{
name: "Nested brackets",
sqlQuery: "SELECT * FROM users WHERE data::jsonb @> '[field]'::jsonb AND id = [user_id]",
expectedVars: []string{"[field]", "[user_id]"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
inputvars := make([]string, 0)
result := handler.extractInputVariables(tt.sqlQuery, &inputvars)
if result != tt.sqlQuery {
t.Errorf("Expected SQL query to be unchanged, got %s", result)
}
if len(inputvars) != len(tt.expectedVars) {
t.Errorf("Expected %d variables, got %d: %v", len(tt.expectedVars), len(inputvars), inputvars)
return
}
for i, expected := range tt.expectedVars {
if inputvars[i] != expected {
t.Errorf("Expected variable %d to be %s, got %s", i, expected, inputvars[i])
}
}
})
}
}
// TestValidSQL tests the SQL sanitization function
func TestValidSQL(t *testing.T) {
tests := []struct {
name string
input string
mode string
expected string
}{
{
name: "Column name with valid characters",
input: "user_id",
mode: "colname",
expected: "user_id",
},
{
name: "Column name with dots (table.column)",
input: "users.user_id",
mode: "colname",
expected: "users.user_id",
},
{
name: "Column name with SQL injection attempt",
input: "id'; DROP TABLE users--",
mode: "colname",
expected: "idDROPTABLEusers",
},
{
name: "Column value with single quotes",
input: "O'Brien",
mode: "colvalue",
expected: "O''Brien",
},
{
name: "Select with dangerous keywords",
input: "name, email; DROP TABLE users",
mode: "select",
expected: "name, email TABLE users",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ValidSQL(tt.input, tt.mode)
if result != tt.expected {
t.Errorf("ValidSQL(%q, %q) = %q, expected %q", tt.input, tt.mode, result, tt.expected)
}
})
}
}
// TestIsNumeric tests the IsNumeric function
func TestIsNumeric(t *testing.T) {
tests := []struct {
input string
expected bool
}{
{"123", true},
{"123.45", true},
{"-123", true},
{"-123.45", true},
{"0", true},
{"abc", false},
{"12.34.56", false},
{"", false},
{"123abc", false},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := IsNumeric(tt.input)
if result != tt.expected {
t.Errorf("IsNumeric(%q) = %v, expected %v", tt.input, result, tt.expected)
}
})
}
}
// TestSqlQryWhere tests the WHERE clause manipulation
func TestSqlQryWhere(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
expected string
}{
{
name: "Add WHERE to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ",
},
{
name: "Add AND to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE id > 0 AND status = 'active' ",
},
{
name: "Add WHERE before ORDER BY",
sqlQuery: "SELECT * FROM users ORDER BY name",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' ORDER BY name",
},
{
name: "Add WHERE before GROUP BY",
sqlQuery: "SELECT COUNT(*) FROM users GROUP BY department",
condition: "status = 'active'",
expected: "SELECT COUNT(*) FROM users WHERE status = 'active' GROUP BY department",
},
{
name: "Add WHERE before LIMIT",
sqlQuery: "SELECT * FROM users LIMIT 10",
condition: "status = 'active'",
expected: "SELECT * FROM users WHERE status = 'active' LIMIT 10",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhere(tt.sqlQuery, tt.condition)
if result != tt.expected {
t.Errorf("sqlQryWhere() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestGetIPAddress tests IP address extraction
func TestGetIPAddress(t *testing.T) {
tests := []struct {
name string
setupReq func() *http.Request
expected string
}{
{
name: "X-Forwarded-For header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "192.168.1.100, 10.0.0.1")
return req
},
expected: "192.168.1.100",
},
{
name: "X-Real-IP header",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
return req
},
expected: "192.168.1.200",
},
{
name: "RemoteAddr fallback",
setupReq: func() *http.Request {
req := httptest.NewRequest("GET", "/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
return req
},
expected: "192.168.1.1:12345",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := tt.setupReq()
result := getIPAddress(req)
if result != tt.expected {
t.Errorf("getIPAddress() = %q, expected %q", result, tt.expected)
}
})
}
}
// TestParsePaginationParams tests pagination parameter parsing
func TestParsePaginationParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
expectedSort string
expectedLimit int
expectedOffset int
}{
{
name: "No parameters - defaults",
queryParams: map[string]string{},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "All parameters provided",
queryParams: map[string]string{
"sort": "name,-created_at",
"limit": "100",
"offset": "50",
},
expectedSort: "name,-created_at",
expectedLimit: 100,
expectedOffset: 50,
},
{
name: "Invalid limit - use default",
queryParams: map[string]string{
"limit": "invalid",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
{
name: "Negative offset - use default",
queryParams: map[string]string{
"offset": "-10",
},
expectedSort: "",
expectedLimit: 20,
expectedOffset: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
sort, limit, offset := handler.parsePaginationParams(req)
if sort != tt.expectedSort {
t.Errorf("Expected sort=%q, got %q", tt.expectedSort, sort)
}
if limit != tt.expectedLimit {
t.Errorf("Expected limit=%d, got %d", tt.expectedLimit, limit)
}
if offset != tt.expectedOffset {
t.Errorf("Expected offset=%d, got %d", tt.expectedOffset, offset)
}
})
}
}
// TestSqlQuery tests the SqlQuery handler for single record queries
func TestSqlQuery(t *testing.T) {
tests := []struct {
name string
sqlQuery string
blankParams bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, body []byte)
}{
{
name: "Basic query - returns single record",
sqlQuery: "SELECT * FROM users WHERE id = 1",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User", "email": "test@example.com"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if result["name"] != "Test User" {
t.Errorf("Expected name='Test User', got %v", result["name"])
}
},
},
{
name: "Query with no results",
sqlQuery: "SELECT * FROM users WHERE id = 999",
blankParams: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
// Return empty array
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, body []byte) {
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 0 {
t.Errorf("Expected empty result, got %v", result)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery(tt.sqlQuery, SqlQueryOptions{BlankParams: tt.blankParams})
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d", tt.expectedStatus, w.Code)
}
if tt.validateResp != nil {
tt.validateResp(t, w.Body.Bytes())
}
})
}
}
// TestSqlQueryList tests the SqlQueryList handler for list queries
func TestSqlQueryList(t *testing.T) {
tests := []struct {
name string
sqlQuery string
noCount bool
blankParams bool
allowFilter bool
queryParams map[string]string
headers map[string]string
setupDB func() *MockDatabase
expectedStatus int
validateResp func(t *testing.T, w *httptest.ResponseRecorder)
}{
{
name: "Basic list query",
sqlQuery: "SELECT * FROM users",
noCount: false,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
callCount := 0
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
callCount++
if strings.Contains(query, "COUNT") {
// Count query
countResult := dest.(*struct{ Count int64 })
countResult.Count = 2
} else {
// Main query
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
{"id": float64(2), "name": "User 2"},
}
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 2 {
t.Errorf("Expected 2 results, got %d", len(result))
}
// Check Content-Range header
contentRange := w.Header().Get("Content-Range")
if !strings.Contains(contentRange, "2") {
t.Errorf("Expected Content-Range to contain total count, got: %s", contentRange)
}
},
},
{
name: "List query with noCount",
sqlQuery: "SELECT * FROM users",
noCount: true,
blankParams: false,
allowFilter: false,
setupDB: func() *MockDatabase {
return &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
db := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
if strings.Contains(query, "COUNT") {
t.Error("Count query should not be executed when noCount is true")
}
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "User 1"},
}
return nil
},
}
return fn(db)
},
}
},
expectedStatus: 200,
validateResp: func(t *testing.T, w *httptest.ResponseRecorder) {
var result []map[string]interface{}
if err := json.Unmarshal(w.Body.Bytes(), &result); err != nil {
t.Fatalf("Failed to unmarshal response: %v", err)
}
if len(result) != 1 {
t.Errorf("Expected 1 result, got %d", len(result))
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
db := tt.setupDB()
handler := NewHandler(db)
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQueryList(tt.sqlQuery, SqlQueryOptions{NoCount: tt.noCount, BlankParams: tt.blankParams, AllowFilter: tt.allowFilter})
handlerFunc(w, req)
if w.Code != tt.expectedStatus {
t.Errorf("Expected status %d, got %d. Body: %s", tt.expectedStatus, w.Code, w.Body.String())
}
if tt.validateResp != nil {
tt.validateResp(t, w)
}
})
}
}
// TestMergeQueryParams tests query parameter merging
func TestMergeQueryParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
queryParams map[string]string
allowFilter bool
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Replace placeholder with parameter",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
queryParams: map[string]string{"p-user_id": "123"},
allowFilter: false,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["p-user_id"] != "123" {
t.Errorf("Expected p-user_id=123, got %v", vars["p-user_id"])
}
},
},
{
name: "Add filter when allowed",
sqlQuery: "SELECT * FROM users",
queryParams: map[string]string{"status": "active"},
allowFilter: true,
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["status"] != "active" {
t.Errorf("Expected status=active, got %v", vars["status"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
result := handler.mergeQueryParams(req, tt.sqlQuery, variables, tt.allowFilter, propQry)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestMergeHeaderParams tests header parameter merging
func TestMergeHeaderParams(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
headers map[string]string
expectedQuery string
checkVars func(t *testing.T, vars map[string]interface{})
}{
{
name: "Field filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-FieldFilter-Status": "1"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-fieldfilter-status"] != "1" {
t.Errorf("Expected x-fieldfilter-status=1, got %v", vars["x-fieldfilter-status"])
}
},
},
{
name: "Search filter header",
sqlQuery: "SELECT * FROM users",
headers: map[string]string{"X-SearchFilter-Name": "john"},
checkVars: func(t *testing.T, vars map[string]interface{}) {
if vars["x-searchfilter-name"] != "john" {
t.Errorf("Expected x-searchfilter-name=john, got %v", vars["x-searchfilter-name"])
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, tt.headers, nil)
variables := make(map[string]interface{})
propQry := make(map[string]string)
complexAPI := false
result := handler.mergeHeaderParams(req, tt.sqlQuery, variables, propQry, &complexAPI)
if result == "" {
t.Error("Expected non-empty SQL query result")
}
if tt.checkVars != nil {
tt.checkVars(t, variables)
}
})
}
}
// TestReplaceMetaVariables tests meta variable replacement
func TestReplaceMetaVariables(t *testing.T) {
handler := NewHandler(&MockDatabase{})
userCtx := &security.UserContext{
UserID: 123,
UserName: "testuser",
SessionID: "ABC456",
SessionRID: 456,
}
metainfo := map[string]interface{}{
"ipaddress": "192.168.1.1",
"url": "/api/test",
}
variables := map[string]interface{}{
"param1": "value1",
}
tests := []struct {
name string
sqlQuery string
expectedCheck func(result string) bool
}{
{
name: "Replace [rid_user]",
sqlQuery: "SELECT * FROM users WHERE created_by = [rid_user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "123")
},
},
{
name: "Replace [user]",
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "'testuser'")
},
},
{
name: "Replace [rid_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [rid_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "456")
},
}, {
name: "Replace [id_session]",
sqlQuery: "SELECT * FROM sessions WHERE session_id = [id_session]",
expectedCheck: func(result string) bool {
return strings.Contains(result, "ABC456")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", nil, nil, nil)
result := handler.replaceMetaVariables(tt.sqlQuery, req, userCtx, metainfo, variables)
if !tt.expectedCheck(result) {
t.Errorf("Meta variable replacement failed. Query: %s", result)
}
})
}
}
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
func TestGetReplacementForBlankParam(t *testing.T) {
tests := []struct {
name string
sqlQuery string
param string
expected string
}{
{
name: "Parameter in single quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]'",
param: "[username]",
expected: "",
},
{
name: "Parameter in dollar quotes",
sqlQuery: "SELECT * FROM users WHERE data = $[jsondata]$",
param: "[jsondata]",
expected: "",
},
{
name: "Parameter not in quotes",
sqlQuery: "SELECT * FROM users WHERE id = [user_id]",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter not in quotes with AND",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND status = 1",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - before quote",
sqlQuery: "SELECT * FROM users WHERE id = [user_id] AND name = 'test'",
param: "[user_id]",
expected: "NULL",
},
{
name: "Parameter in mixed quote context - in quotes",
sqlQuery: "SELECT * FROM users WHERE name = '[username]' AND id = 1",
param: "[username]",
expected: "",
},
{
name: "Parameter with dollar quote tag",
sqlQuery: "SELECT * FROM users WHERE body = $tag$[content]$tag$",
param: "[content]",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
if result != tt.expected {
t.Errorf("Expected replacement '%s', got '%s' for query: %s", tt.expected, result, tt.sqlQuery)
}
})
}
}

160
pkg/funcspec/hooks.go Normal file
View File

@ -0,0 +1,160 @@
package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// HookType defines the type of hook to execute
type HookType string
const (
// Query operation hooks (for SqlQuery - single record)
BeforeQuery HookType = "before_query"
AfterQuery HookType = "after_query"
// Query list operation hooks (for SqlQueryList - multiple records)
BeforeQueryList HookType = "before_query_list"
AfterQueryList HookType = "after_query_list"
// SQL execution hooks (just before SQL is executed)
BeforeSQLExec HookType = "before_sql_exec"
AfterSQLExec HookType = "after_sql_exec"
// Response hooks (before response is sent)
BeforeResponse HookType = "before_response"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler // Reference to the handler for accessing database
Request *http.Request
Writer http.ResponseWriter
// SQL query and variables
SQLQuery string // The SQL query being executed (can be modified by hooks)
Variables map[string]interface{} // Variables extracted from request
InputVars []string // Input variable placeholders found in query
MetaInfo map[string]interface{} // Metadata about the request
PropQry map[string]string // Property query parameters
// User context
UserContext *security.UserContext
// Pagination and filtering (for list queries)
SortColumns string
Limit int
Offset int
// Results
Result interface{} // Query result (single record or list)
Total int64 // Total count (for list queries)
Error error // Error if operation failed
ComplexAPI bool // Whether complex API response format is requested
NoCount bool // Whether count query should be skipped
BlankParams bool // Whether blank parameters should be removed
AllowFilter bool // Whether filtering is allowed
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
}
// HookFunc is the signature for hook functions
// It receives a HookContext and can modify it or return an error
// If an error is returned, the operation will be aborted
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
// NewHookRegistry creates a new hook registry
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
// Register adds a new hook for the specified hook type
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered funcspec hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
// RegisterMultiple registers a hook for multiple hook types
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
// Execute runs all hooks for the specified type in order
// If any hook returns an error, execution stops and the error is returned
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d funcspec hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("Funcspec hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Funcspec hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
// Clear removes all hooks for the specified type
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
logger.Info("Cleared all funcspec hooks for %s", hookType)
}
// ClearAll removes all registered hooks
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
logger.Info("Cleared all funcspec hooks")
}
// Count returns the number of hooks registered for a specific type
func (r *HookRegistry) Count(hookType HookType) int {
if hooks, exists := r.hooks[hookType]; exists {
return len(hooks)
}
return 0
}
// HasHooks returns true if there are any hooks registered for the specified type
func (r *HookRegistry) HasHooks(hookType HookType) bool {
return r.Count(hookType) > 0
}
// GetAllHookTypes returns all hook types that have registered hooks
func (r *HookRegistry) GetAllHookTypes() []HookType {
types := make([]HookType, 0, len(r.hooks))
for hookType := range r.hooks {
types = append(types, hookType)
}
return types
}

View File

@ -0,0 +1,137 @@
package funcspec
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Example hook functions demonstrating various use cases
// ExampleLoggingHook logs all SQL queries before execution
func ExampleLoggingHook(ctx *HookContext) error {
logger.Info("Executing SQL query for user %s: %s", ctx.UserContext.UserName, ctx.SQLQuery)
return nil
}
// ExampleSecurityHook validates user permissions before executing queries
func ExampleSecurityHook(ctx *HookContext) error {
// Example: Block queries that try to access sensitive tables
if strings.Contains(strings.ToLower(ctx.SQLQuery), "sensitive_table") {
if ctx.UserContext.UserID != 1 { // Only admin can access
ctx.Abort = true
ctx.AbortCode = 403
ctx.AbortMessage = "Access denied: insufficient permissions"
return fmt.Errorf("access denied to sensitive_table")
}
}
return nil
}
// ExampleQueryModificationHook modifies SQL queries to add user-specific filtering
func ExampleQueryModificationHook(ctx *HookContext) error {
// Example: Automatically add user_id filter for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
// Add WHERE clause to filter by user_id
if !strings.Contains(strings.ToLower(ctx.SQLQuery), "where") {
ctx.SQLQuery = fmt.Sprintf("%s WHERE user_id = %d", ctx.SQLQuery, ctx.UserContext.UserID)
} else {
ctx.SQLQuery = strings.Replace(
ctx.SQLQuery,
"WHERE",
fmt.Sprintf("WHERE user_id = %d AND", ctx.UserContext.UserID),
1,
)
}
logger.Debug("Modified query for user %d: %s", ctx.UserContext.UserID, ctx.SQLQuery)
}
return nil
}
// ExampleResultFilterHook filters results after query execution
func ExampleResultFilterHook(ctx *HookContext) error {
// Example: Remove sensitive fields from results for non-admin users
if ctx.UserContext.UserID != 1 { // Not admin
switch result := ctx.Result.(type) {
case []map[string]interface{}:
// Filter list results
for i := range result {
delete(result[i], "password")
delete(result[i], "ssn")
delete(result[i], "credit_card")
}
case map[string]interface{}:
// Filter single record
delete(result, "password")
delete(result, "ssn")
delete(result, "credit_card")
}
}
return nil
}
// ExampleAuditHook logs all queries and results for audit purposes
func ExampleAuditHook(ctx *HookContext) error {
// Log to audit table or external system
logger.Info("AUDIT: User %s (%d) executed query from %s",
ctx.UserContext.UserName,
ctx.UserContext.UserID,
ctx.Request.RemoteAddr,
)
// In a real implementation, you might:
// - Insert into an audit log table
// - Send to a logging service
// - Write to a file
return nil
}
// ExampleCacheHook implements simple response caching
func ExampleCacheHook(ctx *HookContext) error {
// This is a simplified example - real caching would use a proper cache store
// Check if we have a cached result for this query
// cacheKey := fmt.Sprintf("%s:%s", ctx.UserContext.UserName, ctx.SQLQuery)
// if cachedResult := checkCache(cacheKey); cachedResult != nil {
// ctx.Result = cachedResult
// ctx.Abort = true // Skip query execution
// ctx.AbortMessage = "Serving from cache"
// }
return nil
}
// ExampleErrorHandlingHook provides custom error handling
func ExampleErrorHandlingHook(ctx *HookContext) error {
if ctx.Error != nil {
// Log error with context
logger.Error("Query failed for user %s: %v\nQuery: %s",
ctx.UserContext.UserName,
ctx.Error,
ctx.SQLQuery,
)
// You could send notifications, update metrics, etc.
}
return nil
}
// Example of registering hooks:
//
// func SetupHooks(handler *Handler) {
// hooks := handler.Hooks()
//
// // Register security hook before query execution
// hooks.Register(BeforeQuery, ExampleSecurityHook)
// hooks.Register(BeforeQueryList, ExampleSecurityHook)
//
// // Register logging hook before SQL execution
// hooks.Register(BeforeSQLExec, ExampleLoggingHook)
//
// // Register result filtering after query
// hooks.Register(AfterQuery, ExampleResultFilterHook)
// hooks.Register(AfterQueryList, ExampleResultFilterHook)
//
// // Register audit hook after execution
// hooks.RegisterMultiple([]HookType{AfterQuery, AfterQueryList}, ExampleAuditHook)
// }

589
pkg/funcspec/hooks_test.go Normal file
View File

@ -0,0 +1,589 @@
package funcspec
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// TestNewHookRegistry tests hook registry creation
func TestNewHookRegistry(t *testing.T) {
registry := NewHookRegistry()
if registry == nil {
t.Fatal("Expected registry to be created, got nil")
}
if registry.hooks == nil {
t.Error("Expected hooks map to be initialized")
}
}
// TestRegisterHook tests registering a single hook
func TestRegisterHook(t *testing.T) {
registry := NewHookRegistry()
hookCalled := false
testHook := func(ctx *HookContext) error {
hookCalled = true
return nil
}
registry.Register(BeforeQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected hook to be registered")
}
if registry.Count(BeforeQuery) != 1 {
t.Errorf("Expected 1 hook, got %d", registry.Count(BeforeQuery))
}
// Execute the hook
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if !hookCalled {
t.Error("Expected hook to be called")
}
}
// TestRegisterMultipleHooks tests registering multiple hooks for same type
func TestRegisterMultipleHooks(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return nil
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
if registry.Count(BeforeQuery) != 3 {
t.Errorf("Expected 3 hooks, got %d", registry.Count(BeforeQuery))
}
// Execute hooks
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
// Verify hooks were called in order
if len(callOrder) != 3 {
t.Errorf("Expected 3 hooks to be called, got %d", len(callOrder))
}
for i, expected := range []int{1, 2, 3} {
if callOrder[i] != expected {
t.Errorf("Expected hook %d at position %d, got %d", expected, i, callOrder[i])
}
}
}
// TestRegisterMultipleHookTypes tests registering a hook for multiple types
func TestRegisterMultipleHookTypes(t *testing.T) {
registry := NewHookRegistry()
callCount := 0
testHook := func(ctx *HookContext) error {
callCount++
return nil
}
hookTypes := []HookType{BeforeQuery, AfterQuery, BeforeSQLExec}
registry.RegisterMultiple(hookTypes, testHook)
// Verify hook is registered for all types
for _, hookType := range hookTypes {
if !registry.HasHooks(hookType) {
t.Errorf("Expected hook to be registered for %s", hookType)
}
if registry.Count(hookType) != 1 {
t.Errorf("Expected 1 hook for %s, got %d", hookType, registry.Count(hookType))
}
}
// Execute each hook type
ctx := &HookContext{}
for _, hookType := range hookTypes {
if err := registry.Execute(hookType, ctx); err != nil {
t.Errorf("Hook execution failed for %s: %v", hookType, err)
}
}
if callCount != 3 {
t.Errorf("Expected hook to be called 3 times, got %d", callCount)
}
}
// TestHookError tests hook error handling
func TestHookError(t *testing.T) {
registry := NewHookRegistry()
expectedError := fmt.Errorf("test error")
errorHook := func(ctx *HookContext) error {
return expectedError
}
registry.Register(BeforeQuery, errorHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook, got nil")
}
if err.Error() != fmt.Sprintf("hook execution failed: %v", expectedError) {
t.Errorf("Expected error message to contain hook error, got: %v", err)
}
}
// TestHookAbort tests hook abort functionality
func TestHookAbort(t *testing.T) {
registry := NewHookRegistry()
abortHook := func(ctx *HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "Operation aborted by hook"
ctx.AbortCode = 403
return nil
}
registry.Register(BeforeQuery, abortHook)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error when hook aborts, got nil")
}
if !ctx.Abort {
t.Error("Expected Abort to be true")
}
if ctx.AbortMessage != "Operation aborted by hook" {
t.Errorf("Expected abort message, got: %s", ctx.AbortMessage)
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got: %d", ctx.AbortCode)
}
}
// TestHookChainWithError tests that hook chain stops on first error
func TestHookChainWithError(t *testing.T) {
registry := NewHookRegistry()
callOrder := []int{}
hook1 := func(ctx *HookContext) error {
callOrder = append(callOrder, 1)
return nil
}
hook2 := func(ctx *HookContext) error {
callOrder = append(callOrder, 2)
return fmt.Errorf("error in hook 2")
}
hook3 := func(ctx *HookContext) error {
callOrder = append(callOrder, 3)
return nil
}
registry.Register(BeforeQuery, hook1)
registry.Register(BeforeQuery, hook2)
registry.Register(BeforeQuery, hook3)
ctx := &HookContext{}
err := registry.Execute(BeforeQuery, ctx)
if err == nil {
t.Error("Expected error from hook chain")
}
// Only first two hooks should have been called
if len(callOrder) != 2 {
t.Errorf("Expected 2 hooks to be called, got %d", len(callOrder))
}
if callOrder[0] != 1 || callOrder[1] != 2 {
t.Errorf("Expected hooks 1 and 2 to be called, got: %v", callOrder)
}
}
// TestClearHooks tests clearing hooks
func TestClearHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
if !registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hook to be registered")
}
registry.Clear(BeforeQuery)
if registry.HasHooks(BeforeQuery) {
t.Error("Expected BeforeQuery hooks to be cleared")
}
if !registry.HasHooks(AfterQuery) {
t.Error("Expected AfterQuery hook to still be registered")
}
}
// TestClearAllHooks tests clearing all hooks
func TestClearAllHooks(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
registry.Register(BeforeSQLExec, testHook)
registry.ClearAll()
if registry.HasHooks(BeforeQuery) || registry.HasHooks(AfterQuery) || registry.HasHooks(BeforeSQLExec) {
t.Error("Expected all hooks to be cleared")
}
}
// TestGetAllHookTypes tests getting all registered hook types
func TestGetAllHookTypes(t *testing.T) {
registry := NewHookRegistry()
testHook := func(ctx *HookContext) error {
return nil
}
registry.Register(BeforeQuery, testHook)
registry.Register(AfterQuery, testHook)
types := registry.GetAllHookTypes()
if len(types) != 2 {
t.Errorf("Expected 2 hook types, got %d", len(types))
}
// Verify the types are present
foundBefore := false
foundAfter := false
for _, hookType := range types {
if hookType == BeforeQuery {
foundBefore = true
}
if hookType == AfterQuery {
foundAfter = true
}
}
if !foundBefore || !foundAfter {
t.Error("Expected both BeforeQuery and AfterQuery hook types")
}
}
// TestHookContextModification tests that hooks can modify the context
func TestHookContextModification(t *testing.T) {
registry := NewHookRegistry()
// Hook that modifies SQL query
modifyHook := func(ctx *HookContext) error {
ctx.SQLQuery = "SELECT * FROM modified_table"
ctx.Variables["new_var"] = "new_value"
return nil
}
registry.Register(BeforeQuery, modifyHook)
ctx := &HookContext{
SQLQuery: "SELECT * FROM original_table",
Variables: make(map[string]interface{}),
}
err := registry.Execute(BeforeQuery, ctx)
if err != nil {
t.Errorf("Hook execution failed: %v", err)
}
if ctx.SQLQuery != "SELECT * FROM modified_table" {
t.Errorf("Expected SQL query to be modified, got: %s", ctx.SQLQuery)
}
if ctx.Variables["new_var"] != "new_value" {
t.Errorf("Expected variable to be added, got: %v", ctx.Variables)
}
}
// TestExampleHooks tests the example hooks
func TestExampleLoggingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleLoggingHook(ctx)
if err != nil {
t.Errorf("ExampleLoggingHook failed: %v", err)
}
}
func TestExampleSecurityHook(t *testing.T) {
tests := []struct {
name string
sqlQuery string
userID int
shouldAbort bool
}{
{
name: "Admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 1,
shouldAbort: false,
},
{
name: "Non-admin accessing sensitive table",
sqlQuery: "SELECT * FROM sensitive_table",
userID: 2,
shouldAbort: true,
},
{
name: "Non-admin accessing normal table",
sqlQuery: "SELECT * FROM users",
userID: 2,
shouldAbort: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: tt.sqlQuery,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
_ = ExampleSecurityHook(ctx)
if tt.shouldAbort {
if !ctx.Abort {
t.Error("Expected security hook to abort operation")
}
if ctx.AbortCode != 403 {
t.Errorf("Expected abort code 403, got %d", ctx.AbortCode)
}
} else {
if ctx.Abort {
t.Error("Expected security hook not to abort operation")
}
}
})
}
}
func TestExampleResultFilterHook(t *testing.T) {
tests := []struct {
name string
userID int
result interface{}
validate func(t *testing.T, result interface{})
}{
{
name: "Admin user - no filtering",
userID: 1,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; !exists {
t.Error("Expected password field to remain for admin")
}
},
},
{
name: "Regular user - sensitive fields removed",
userID: 2,
result: map[string]interface{}{
"id": 1,
"name": "Test",
"password": "secret",
"ssn": "123-45-6789",
},
validate: func(t *testing.T, result interface{}) {
m := result.(map[string]interface{})
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed")
}
if _, exists := m["ssn"]; exists {
t.Error("Expected ssn field to be removed")
}
if _, exists := m["name"]; !exists {
t.Error("Expected name field to remain")
}
},
},
{
name: "Regular user - list results filtered",
userID: 2,
result: []map[string]interface{}{
{"id": 1, "name": "User 1", "password": "secret1"},
{"id": 2, "name": "User 2", "password": "secret2"},
},
validate: func(t *testing.T, result interface{}) {
list := result.([]map[string]interface{})
for _, m := range list {
if _, exists := m["password"]; exists {
t.Error("Expected password field to be removed from list")
}
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
Result: tt.result,
UserContext: &security.UserContext{
UserID: tt.userID,
},
}
err := ExampleResultFilterHook(ctx)
if err != nil {
t.Errorf("Hook failed: %v", err)
}
if tt.validate != nil {
tt.validate(t, ctx.Result)
}
})
}
}
func TestExampleAuditHook(t *testing.T) {
req := httptest.NewRequest("GET", "/api/test", nil)
req.RemoteAddr = "192.168.1.1:12345"
ctx := &HookContext{
Context: context.Background(),
Request: req,
UserContext: &security.UserContext{
UserID: 123,
UserName: "testuser",
},
}
err := ExampleAuditHook(ctx)
if err != nil {
t.Errorf("ExampleAuditHook failed: %v", err)
}
}
func TestExampleErrorHandlingHook(t *testing.T) {
ctx := &HookContext{
Context: context.Background(),
SQLQuery: "SELECT * FROM test",
Error: fmt.Errorf("test error"),
UserContext: &security.UserContext{
UserName: "testuser",
},
}
err := ExampleErrorHandlingHook(ctx)
if err != nil {
t.Errorf("ExampleErrorHandlingHook failed: %v", err)
}
}
// TestHookIntegrationWithHandler tests hooks integrated with the handler
func TestHookIntegrationWithHandler(t *testing.T) {
db := &MockDatabase{
RunInTransactionFunc: func(ctx context.Context, fn func(common.Database) error) error {
queryDB := &MockDatabase{
QueryFunc: func(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
rows := dest.(*[]map[string]interface{})
*rows = []map[string]interface{}{
{"id": float64(1), "name": "Test User"},
}
return nil
},
}
return fn(queryDB)
},
}
handler := NewHandler(db)
// Register a hook that modifies the SQL query
hookCalled := false
handler.Hooks().Register(BeforeSQLExec, func(ctx *HookContext) error {
hookCalled = true
// Verify we can access context data
if ctx.SQLQuery == "" {
t.Error("Expected SQL query to be set")
}
if ctx.UserContext == nil {
t.Error("Expected user context to be set")
}
return nil
})
// Execute a query
req := createTestRequest("GET", "/test", nil, nil, nil)
w := httptest.NewRecorder()
handlerFunc := handler.SqlQuery("SELECT * FROM users WHERE id = 1", SqlQueryOptions{})
handlerFunc(w, req)
if !hookCalled {
t.Error("Expected hook to be called during query execution")
}
if w.Code != 200 {
t.Errorf("Expected status 200, got %d", w.Code)
}
}

411
pkg/funcspec/parameters.go Normal file
View File

@ -0,0 +1,411 @@
package funcspec
import (
"fmt"
"net/http"
"strconv"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/restheadspec"
)
// RequestParameters holds parsed parameters from headers and query string
type RequestParameters struct {
// Field selection
SelectFields []string
NotSelectFields []string
Distinct bool
// Filtering
FieldFilters map[string]string // column -> value (exact match)
SearchFilters map[string]string // column -> value (ILIKE)
SearchOps map[string]FilterOperator // column -> {operator, value, logic}
CustomSQLWhere string
CustomSQLOr string
// Sorting & Pagination
SortColumns string
Limit int
Offset int
// Advanced features
SkipCount bool
SkipCache bool
// Response format
ResponseFormat string // "simple", "detail", "syncfusion"
ComplexAPI bool // true if NOT simple API
}
// FilterOperator represents a filter with operator
type FilterOperator struct {
Operator string // eq, neq, gt, lt, gte, lte, like, ilike, in, between, etc.
Value string
Logic string // AND or OR
}
// ParseParameters parses all parameters from request headers and query string
func (h *Handler) ParseParameters(r *http.Request) *RequestParameters {
params := &RequestParameters{
FieldFilters: make(map[string]string),
SearchFilters: make(map[string]string),
SearchOps: make(map[string]FilterOperator),
Limit: 20, // Default limit
Offset: 0, // Default offset
ResponseFormat: "simple", // Default format
ComplexAPI: false, // Default to simple API
}
// Merge headers and query parameters
combined := make(map[string]string)
// Add all headers (normalize to lowercase)
for key, values := range r.Header {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Add all query parameters (override headers)
for key, values := range r.URL.Query() {
if len(values) > 0 {
combined[strings.ToLower(key)] = values[0]
}
}
// Parse each parameter
for key, value := range combined {
// Decode value if base64 encoded
decodedValue := h.decodeValue(value)
switch {
// Field Selection
case strings.HasPrefix(key, "x-select-fields"):
params.SelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-not-select-fields"):
params.NotSelectFields = h.parseCommaSeparated(decodedValue)
case strings.HasPrefix(key, "x-distinct"):
params.Distinct = strings.EqualFold(decodedValue, "true")
// Filtering
case strings.HasPrefix(key, "x-fieldfilter-"):
colName := strings.TrimPrefix(key, "x-fieldfilter-")
params.FieldFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchfilter-"):
colName := strings.TrimPrefix(key, "x-searchfilter-")
params.SearchFilters[colName] = decodedValue
case strings.HasPrefix(key, "x-searchop-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-searchor-"):
h.parseSearchOp(params, key, decodedValue, "OR")
case strings.HasPrefix(key, "x-searchand-"):
h.parseSearchOp(params, key, decodedValue, "AND")
case strings.HasPrefix(key, "x-custom-sql-w"):
if params.CustomSQLWhere != "" {
params.CustomSQLWhere = fmt.Sprintf("%s AND (%s)", params.CustomSQLWhere, decodedValue)
} else {
params.CustomSQLWhere = decodedValue
}
case strings.HasPrefix(key, "x-custom-sql-or"):
if params.CustomSQLOr != "" {
params.CustomSQLOr = fmt.Sprintf("%s OR (%s)", params.CustomSQLOr, decodedValue)
} else {
params.CustomSQLOr = decodedValue
}
// Sorting & Pagination
case key == "sort" || strings.HasPrefix(key, "x-sort"):
params.SortColumns = decodedValue
case strings.HasPrefix(key, "sort(") && strings.Contains(key, ")"):
// Handle sort(col1,-col2) syntax
sortValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
params.SortColumns = sortValue
case key == "limit" || strings.HasPrefix(key, "x-limit"):
if limit, err := strconv.Atoi(decodedValue); err == nil && limit > 0 {
params.Limit = limit
}
case strings.HasPrefix(key, "limit(") && strings.Contains(key, ")"):
// Handle limit(offset,limit) or limit(limit) syntax
limitValue := key[strings.Index(key, "(")+1 : strings.Index(key, ")")]
parts := strings.Split(limitValue, ",")
if len(parts) > 1 {
if offset, err := strconv.Atoi(parts[0]); err == nil {
params.Offset = offset
}
if limit, err := strconv.Atoi(parts[1]); err == nil {
params.Limit = limit
}
} else {
if limit, err := strconv.Atoi(parts[0]); err == nil {
params.Limit = limit
}
}
case key == "offset" || strings.HasPrefix(key, "x-offset"):
if offset, err := strconv.Atoi(decodedValue); err == nil && offset >= 0 {
params.Offset = offset
}
// Advanced features
case strings.HasPrefix(key, "x-skipcount"):
params.SkipCount = strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-skipcache"):
params.SkipCache = strings.EqualFold(decodedValue, "true")
// Response Format
case strings.HasPrefix(key, "x-simpleapi"):
params.ResponseFormat = "simple"
params.ComplexAPI = decodedValue != "1" && !strings.EqualFold(decodedValue, "true")
case strings.HasPrefix(key, "x-detailapi"):
params.ResponseFormat = "detail"
params.ComplexAPI = true
case strings.HasPrefix(key, "x-syncfusion"):
params.ResponseFormat = "syncfusion"
params.ComplexAPI = true
}
}
return params
}
// parseSearchOp parses x-searchop-{operator}-{column} or x-searchor-{operator}-{column}
func (h *Handler) parseSearchOp(params *RequestParameters, headerKey, value, logic string) {
var prefix string
if logic == "OR" {
prefix = "x-searchor-"
} else {
prefix = "x-searchop-"
if strings.HasPrefix(headerKey, "x-searchand-") {
prefix = "x-searchand-"
}
}
rest := strings.TrimPrefix(headerKey, prefix)
parts := strings.SplitN(rest, "-", 2)
if len(parts) != 2 {
logger.Warn("Invalid search operator header format: %s", headerKey)
return
}
operator := parts[0]
colName := parts[1]
params.SearchOps[colName] = FilterOperator{
Operator: operator,
Value: value,
Logic: logic,
}
logger.Debug("%s search operator: %s %s %s", logic, colName, operator, value)
}
// decodeValue decodes base64 encoded values (ZIP_ or __ prefix)
func (h *Handler) decodeValue(value string) string {
decoded, _ := restheadspec.DecodeParam(value)
return decoded
}
// parseCommaSeparated parses comma-separated values
func (h *Handler) parseCommaSeparated(value string) []string {
if value == "" {
return nil
}
parts := strings.Split(value, ",")
result := make([]string, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part != "" {
result = append(result, part)
}
}
return result
}
// ApplyFieldSelection applies column selection to SQL query
func (h *Handler) ApplyFieldSelection(sqlQuery string, params *RequestParameters) string {
if len(params.SelectFields) == 0 && len(params.NotSelectFields) == 0 {
return sqlQuery
}
// This is a simplified implementation
// A full implementation would parse the SQL and replace the SELECT clause
// For now, we log a warning that this feature needs manual implementation
if len(params.SelectFields) > 0 {
logger.Debug("Field selection requested: %v (manual SQL adjustment may be needed)", params.SelectFields)
}
if len(params.NotSelectFields) > 0 {
logger.Debug("Field exclusion requested: %v (manual SQL adjustment may be needed)", params.NotSelectFields)
}
return sqlQuery
}
// ApplyFilters applies all filters to the SQL query
func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) string {
// Apply field filters (exact match)
for colName, value := range params.FieldFilters {
condition := ""
if value == "" || value == "0" {
condition = fmt.Sprintf("COALESCE(%s, 0) = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
} else {
condition = fmt.Sprintf("%s = %s", ValidSQL(colName, "colname"), ValidSQL(value, "colvalue"))
}
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied field filter: %s", condition)
}
// Apply search filters (ILIKE)
for colName, value := range params.SearchFilters {
sval := strings.ReplaceAll(value, "'", "")
if sval != "" {
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
sqlQuery = sqlQryWhere(sqlQuery, condition)
logger.Debug("Applied search filter: %s", condition)
}
}
// Apply search operators
for colName, filterOp := range params.SearchOps {
condition := h.buildFilterCondition(colName, filterOp)
if condition != "" {
if filterOp.Logic == "OR" {
sqlQuery = sqlQryWhereOr(sqlQuery, condition)
} else {
sqlQuery = sqlQryWhere(sqlQuery, condition)
}
logger.Debug("Applied search operator: %s", condition)
}
}
// Apply custom SQL WHERE
if params.CustomSQLWhere != "" {
colval := ValidSQL(params.CustomSQLWhere, "select")
if colval != "" {
sqlQuery = sqlQryWhere(sqlQuery, colval)
logger.Debug("Applied custom SQL WHERE: %s", colval)
}
}
// Apply custom SQL OR
if params.CustomSQLOr != "" {
colval := ValidSQL(params.CustomSQLOr, "select")
if colval != "" {
sqlQuery = sqlQryWhereOr(sqlQuery, colval)
logger.Debug("Applied custom SQL OR: %s", colval)
}
}
return sqlQuery
}
// buildFilterCondition builds a SQL condition from a FilterOperator
func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string {
safCol := ValidSQL(colName, "colname")
operator := strings.ToLower(op.Operator)
value := op.Value
switch operator {
case "contains", "contain", "like":
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
case "beginswith", "startswith":
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
case "endswith":
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
case "equals", "eq", "=":
if IsNumeric(value) {
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
case "notequals", "neq", "ne", "!=", "<>":
if IsNumeric(value) {
return fmt.Sprintf("%s != %s", safCol, ValidSQL(value, "colvalue"))
}
return fmt.Sprintf("%s != '%s'", safCol, ValidSQL(value, "colvalue"))
case "greaterthan", "gt", ">":
return fmt.Sprintf("%s > %s", safCol, ValidSQL(value, "colvalue"))
case "lessthan", "lt", "<":
return fmt.Sprintf("%s < %s", safCol, ValidSQL(value, "colvalue"))
case "greaterthanorequal", "gte", "ge", ">=":
return fmt.Sprintf("%s >= %s", safCol, ValidSQL(value, "colvalue"))
case "lessthanorequal", "lte", "le", "<=":
return fmt.Sprintf("%s <= %s", safCol, ValidSQL(value, "colvalue"))
case "between":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s > %s AND %s < %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "betweeninclusive":
parts := strings.Split(value, ",")
if len(parts) == 2 {
return fmt.Sprintf("%s >= %s AND %s <= %s", safCol, ValidSQL(parts[0], "colvalue"), safCol, ValidSQL(parts[1], "colvalue"))
}
case "in":
values := strings.Split(value, ",")
safeValues := make([]string, len(values))
for i, v := range values {
safeValues[i] = fmt.Sprintf("'%s'", ValidSQL(v, "colvalue"))
}
return fmt.Sprintf("%s IN (%s)", safCol, strings.Join(safeValues, ", "))
case "empty", "isnull", "null":
return fmt.Sprintf("(%s IS NULL OR %s = '')", safCol, safCol)
case "notempty", "isnotnull", "notnull":
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", safCol, safCol)
default:
logger.Warn("Unknown filter operator: %s, defaulting to equals", operator)
return fmt.Sprintf("%s = '%s'", safCol, ValidSQL(value, "colvalue"))
}
return ""
}
// ApplyDistinct adds DISTINCT to SQL query if requested
func (h *Handler) ApplyDistinct(sqlQuery string, params *RequestParameters) string {
if !params.Distinct {
return sqlQuery
}
// Add DISTINCT after SELECT
selectPos := strings.Index(strings.ToUpper(sqlQuery), "SELECT")
if selectPos >= 0 {
beforeSelect := sqlQuery[:selectPos+6] // "SELECT"
afterSelect := sqlQuery[selectPos+6:]
sqlQuery = beforeSelect + " DISTINCT" + afterSelect
logger.Debug("Applied DISTINCT to query")
}
return sqlQuery
}
// sqlQryWhereOr adds a WHERE clause with OR logic
func sqlQryWhereOr(sqlquery, condition string) string {
lowerQuery := strings.ToLower(sqlquery)
wherePos := strings.Index(lowerQuery, " where ")
groupPos := strings.Index(lowerQuery, " group by")
orderPos := strings.Index(lowerQuery, " order by")
limitPos := strings.Index(lowerQuery, " limit ")
// Find the insertion point
insertPos := len(sqlquery)
if groupPos > 0 && groupPos < insertPos {
insertPos = groupPos
}
if orderPos > 0 && orderPos < insertPos {
insertPos = orderPos
}
if limitPos > 0 && limitPos < insertPos {
insertPos = limitPos
}
if wherePos > 0 {
// WHERE exists, add OR condition
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s OR (%s) %s", before, condition, after)
} else {
// No WHERE exists, add it
before := sqlquery[:insertPos]
after := sqlquery[insertPos:]
return fmt.Sprintf("%s WHERE %s %s", before, condition, after)
}
}

View File

@ -0,0 +1,549 @@
package funcspec
import (
"strings"
"testing"
)
// TestParseParameters tests the comprehensive parameter parsing
func TestParseParameters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
queryParams map[string]string
headers map[string]string
validate func(t *testing.T, params *RequestParameters)
}{
{
name: "Parse field selection",
headers: map[string]string{
"X-Select-Fields": "id,name,email",
"X-Not-Select-Fields": "password,ssn",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SelectFields) != 3 {
t.Errorf("Expected 3 select fields, got %d", len(params.SelectFields))
}
if len(params.NotSelectFields) != 2 {
t.Errorf("Expected 2 not-select fields, got %d", len(params.NotSelectFields))
}
},
},
{
name: "Parse distinct flag",
headers: map[string]string{
"X-Distinct": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.Distinct {
t.Error("Expected Distinct to be true")
}
},
},
{
name: "Parse field filters",
headers: map[string]string{
"X-FieldFilter-Status": "active",
"X-FieldFilter-Type": "admin",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.FieldFilters) != 2 {
t.Errorf("Expected 2 field filters, got %d", len(params.FieldFilters))
}
if params.FieldFilters["status"] != "active" {
t.Errorf("Expected status filter=active, got %s", params.FieldFilters["status"])
}
},
},
{
name: "Parse search filters",
headers: map[string]string{
"X-SearchFilter-Name": "john",
"X-SearchFilter-Email": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchFilters) != 2 {
t.Errorf("Expected 2 search filters, got %d", len(params.SearchFilters))
}
},
},
{
name: "Parse sort columns",
queryParams: map[string]string{
"sort": "-created_at,name",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.SortColumns != "-created_at,name" {
t.Errorf("Expected sort columns=-created_at,name, got %s", params.SortColumns)
}
},
},
{
name: "Parse limit and offset",
queryParams: map[string]string{
"limit": "100",
"offset": "50",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.Limit != 100 {
t.Errorf("Expected limit=100, got %d", params.Limit)
}
if params.Offset != 50 {
t.Errorf("Expected offset=50, got %d", params.Offset)
}
},
},
{
name: "Parse skip count",
headers: map[string]string{
"X-SkipCount": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if !params.SkipCount {
t.Error("Expected SkipCount to be true")
}
},
},
{
name: "Parse response format - syncfusion",
headers: map[string]string{
"X-Syncfusion": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "syncfusion" {
t.Errorf("Expected ResponseFormat=syncfusion, got %s", params.ResponseFormat)
}
if !params.ComplexAPI {
t.Error("Expected ComplexAPI to be true for syncfusion format")
}
},
},
{
name: "Parse response format - detail",
headers: map[string]string{
"X-DetailAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "detail" {
t.Errorf("Expected ResponseFormat=detail, got %s", params.ResponseFormat)
}
},
},
{
name: "Parse simple API",
headers: map[string]string{
"X-SimpleAPI": "true",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.ResponseFormat != "simple" {
t.Errorf("Expected ResponseFormat=simple, got %s", params.ResponseFormat)
}
if params.ComplexAPI {
t.Error("Expected ComplexAPI to be false for simple API")
}
},
},
{
name: "Parse custom SQL WHERE",
headers: map[string]string{
"X-Custom-SQL-W": "status = 'active' AND deleted = false",
},
validate: func(t *testing.T, params *RequestParameters) {
if params.CustomSQLWhere == "" {
t.Error("Expected CustomSQLWhere to be set")
}
},
},
{
name: "Parse search operators - AND",
headers: map[string]string{
"X-SearchOp-Eq-Name": "john",
"X-SearchOp-Gt-Age": "18",
},
validate: func(t *testing.T, params *RequestParameters) {
if len(params.SearchOps) != 2 {
t.Errorf("Expected 2 search operators, got %d", len(params.SearchOps))
}
if op, exists := params.SearchOps["name"]; exists {
if op.Operator != "eq" {
t.Errorf("Expected operator=eq for name, got %s", op.Operator)
}
if op.Logic != "AND" {
t.Errorf("Expected logic=AND, got %s", op.Logic)
}
} else {
t.Error("Expected name search operator to exist")
}
},
},
{
name: "Parse search operators - OR",
headers: map[string]string{
"X-SearchOr-Like-Description": "test",
},
validate: func(t *testing.T, params *RequestParameters) {
if op, exists := params.SearchOps["description"]; exists {
if op.Logic != "OR" {
t.Errorf("Expected logic=OR, got %s", op.Logic)
}
} else {
t.Error("Expected description search operator to exist")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := createTestRequest("GET", "/test", tt.queryParams, tt.headers, nil)
params := handler.ParseParameters(req)
if tt.validate != nil {
tt.validate(t, params)
}
})
}
}
// TestBuildFilterCondition tests the filter condition builder
func TestBuildFilterCondition(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
colName string
operator FilterOperator
expected string
}{
{
name: "Equals operator - numeric",
colName: "age",
operator: FilterOperator{
Operator: "eq",
Value: "25",
Logic: "AND",
},
expected: "age = 25",
},
{
name: "Equals operator - string",
colName: "name",
operator: FilterOperator{
Operator: "eq",
Value: "john",
Logic: "AND",
},
expected: "name = 'john'",
},
{
name: "Not equals operator",
colName: "status",
operator: FilterOperator{
Operator: "neq",
Value: "inactive",
Logic: "AND",
},
expected: "status != 'inactive'",
},
{
name: "Greater than operator",
colName: "age",
operator: FilterOperator{
Operator: "gt",
Value: "18",
Logic: "AND",
},
expected: "age > 18",
},
{
name: "Less than operator",
colName: "price",
operator: FilterOperator{
Operator: "lt",
Value: "100",
Logic: "AND",
},
expected: "price < 100",
},
{
name: "Contains operator",
colName: "description",
operator: FilterOperator{
Operator: "contains",
Value: "test",
Logic: "AND",
},
expected: "description ILIKE '%test%'",
},
{
name: "Starts with operator",
colName: "name",
operator: FilterOperator{
Operator: "startswith",
Value: "john",
Logic: "AND",
},
expected: "name ILIKE 'john%'",
},
{
name: "Ends with operator",
colName: "email",
operator: FilterOperator{
Operator: "endswith",
Value: "@example.com",
Logic: "AND",
},
expected: "email ILIKE '%@example.com'",
},
{
name: "Between operator",
colName: "age",
operator: FilterOperator{
Operator: "between",
Value: "18,65",
Logic: "AND",
},
expected: "age > 18 AND age < 65",
},
{
name: "IN operator",
colName: "status",
operator: FilterOperator{
Operator: "in",
Value: "active,pending,approved",
Logic: "AND",
},
expected: "status IN ('active', 'pending', 'approved')",
},
{
name: "IS NULL operator",
colName: "deleted_at",
operator: FilterOperator{
Operator: "null",
Value: "",
Logic: "AND",
},
expected: "(deleted_at IS NULL OR deleted_at = '')",
},
{
name: "IS NOT NULL operator",
colName: "created_at",
operator: FilterOperator{
Operator: "notnull",
Value: "",
Logic: "AND",
},
expected: "(created_at IS NOT NULL AND created_at != '')",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.buildFilterCondition(tt.colName, tt.operator)
if result != tt.expected {
t.Errorf("Expected: %s\nGot: %s", tt.expected, result)
}
})
}
}
// TestApplyFilters tests the filter application to SQL queries
func TestApplyFilters(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
params *RequestParameters
expectedSQL string
shouldContain []string
}{
{
name: "Apply field filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
FieldFilters: map[string]string{
"status": "active",
},
},
shouldContain: []string{"WHERE", "status"},
},
{
name: "Apply search filter",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchFilters: map[string]string{
"name": "john",
},
},
shouldContain: []string{"WHERE", "name", "ILIKE"},
},
{
name: "Apply search operators",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
SearchOps: map[string]FilterOperator{
"age": {
Operator: "gt",
Value: "18",
Logic: "AND",
},
},
},
shouldContain: []string{"WHERE", "age", ">", "18"},
},
{
name: "Apply custom SQL WHERE",
sqlQuery: "SELECT * FROM users",
params: &RequestParameters{
CustomSQLWhere: "deleted = false",
},
shouldContain: []string{"WHERE", "deleted"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.ApplyFilters(tt.sqlQuery, tt.params)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}
// TestApplyDistinct tests DISTINCT application
func TestApplyDistinct(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
sqlQuery string
distinct bool
shouldHave string
}{
{
name: "Apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: true,
shouldHave: "SELECT DISTINCT",
},
{
name: "Do not apply DISTINCT",
sqlQuery: "SELECT id, name FROM users",
distinct: false,
shouldHave: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
params := &RequestParameters{Distinct: tt.distinct}
result := handler.ApplyDistinct(tt.sqlQuery, params)
if tt.shouldHave != "" {
if !strings.Contains(result, tt.shouldHave) {
t.Errorf("Expected SQL to contain %q, got: %s", tt.shouldHave, result)
}
} else {
// Should not have DISTINCT when not requested
if strings.Contains(result, "DISTINCT") && !tt.distinct {
t.Errorf("SQL should not contain DISTINCT when not requested: %s", result)
}
}
})
}
}
// TestParseCommaSeparated tests comma-separated value parsing
func TestParseCommaSeparated(t *testing.T) {
handler := NewHandler(&MockDatabase{})
tests := []struct {
name string
input string
expected []string
}{
{
name: "Simple comma-separated",
input: "id,name,email",
expected: []string{"id", "name", "email"},
},
{
name: "With spaces",
input: "id, name, email",
expected: []string{"id", "name", "email"},
},
{
name: "Empty string",
input: "",
expected: nil,
},
{
name: "Single value",
input: "id",
expected: []string{"id"},
},
{
name: "With extra commas",
input: "id,,name,,email",
expected: []string{"id", "name", "email"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := handler.parseCommaSeparated(tt.input)
if len(result) != len(tt.expected) {
t.Errorf("Expected %d values, got %d", len(tt.expected), len(result))
return
}
for i, expected := range tt.expected {
if result[i] != expected {
t.Errorf("Expected value %d to be %s, got %s", i, expected, result[i])
}
}
})
}
}
// TestSqlQryWhereOr tests OR WHERE clause manipulation
func TestSqlQryWhereOr(t *testing.T) {
tests := []struct {
name string
sqlQuery string
condition string
shouldContain []string
}{
{
name: "Add WHERE with OR to query without WHERE",
sqlQuery: "SELECT * FROM users",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "status = 'inactive'"},
},
{
name: "Add OR to query with existing WHERE",
sqlQuery: "SELECT * FROM users WHERE id > 0",
condition: "status = 'inactive'",
shouldContain: []string{"WHERE", "OR", "(status = 'inactive')"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sqlQryWhereOr(tt.sqlQuery, tt.condition)
for _, expected := range tt.shouldContain {
if !strings.Contains(result, expected) {
t.Errorf("Expected SQL to contain %q, got: %s", expected, result)
}
}
})
}
}

View File

@ -0,0 +1,83 @@
package funcspec
import (
"context"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 2: BeforeQuery - Audit logging before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Note: Row-level security and column masking are challenging in funcspec
// because the SQL query is fully user-defined. Security should be implemented
// at the SQL function level or through database policies (RLS).
}
// funcSpecSecurityContext adapts funcspec.HookContext to security.SecurityContext interface
type funcSpecSecurityContext struct {
ctx *HookContext
}
func newFuncSpecSecurityContext(ctx *HookContext) security.SecurityContext {
return &funcSpecSecurityContext{ctx: ctx}
}
func (f *funcSpecSecurityContext) GetContext() context.Context {
return f.ctx.Context
}
func (f *funcSpecSecurityContext) GetUserID() (int, bool) {
if f.ctx.UserContext == nil {
return 0, false
}
return int(f.ctx.UserContext.UserID), true
}
func (f *funcSpecSecurityContext) GetSchema() string {
// funcspec doesn't have a schema concept, extract from SQL query or use default
return "public"
}
func (f *funcSpecSecurityContext) GetEntity() string {
// funcspec doesn't have an entity concept, could parse from SQL or use a placeholder
return "sql_query"
}
func (f *funcSpecSecurityContext) GetModel() interface{} {
// funcspec doesn't use models in the same way as restheadspec
return nil
}
func (f *funcSpecSecurityContext) GetQuery() interface{} {
// In funcspec, the query is a string, not a query builder object
return f.ctx.SQLQuery
}
func (f *funcSpecSecurityContext) SetQuery(query interface{}) {
// In funcspec, we could modify the SQL string, but this should be done cautiously
if sqlQuery, ok := query.(string); ok {
f.ctx.SQLQuery = sqlQuery
}
}
func (f *funcSpecSecurityContext) GetResult() interface{} {
return f.ctx.Result
}
func (f *funcSpecSecurityContext) SetResult(result interface{}) {
f.ctx.Result = result
}

View File

@ -1,14 +1,19 @@
package logger package logger
import ( import (
"context"
"fmt" "fmt"
"log" "log"
"os" "os"
"runtime/debug"
"go.uber.org/zap" "go.uber.org/zap"
errortracking "github.com/bitechdev/ResolveSpec/pkg/errortracking"
) )
var Logger *zap.SugaredLogger var Logger *zap.SugaredLogger
var errorTracker errortracking.Provider
func Init(dev bool) { func Init(dev bool) {
@ -22,6 +27,15 @@ func Init(dev bool) {
} }
func UpdateLoggerPath(path string, dev bool) {
defaultConfig := zap.NewProductionConfig()
if dev {
defaultConfig = zap.NewDevelopmentConfig()
}
defaultConfig.OutputPaths = []string{path}
UpdateLogger(&defaultConfig)
}
func UpdateLogger(config *zap.Config) { func UpdateLogger(config *zap.Config) {
defaultConfig := zap.NewProductionConfig() defaultConfig := zap.NewProductionConfig()
defaultConfig.OutputPaths = []string{"resolvespec.log"} defaultConfig.OutputPaths = []string{"resolvespec.log"}
@ -39,6 +53,28 @@ func UpdateLogger(config *zap.Config) {
Info("ResolveSpec Logger initialized") Info("ResolveSpec Logger initialized")
} }
// InitErrorTracking initializes the error tracking provider
func InitErrorTracking(provider errortracking.Provider) {
errorTracker = provider
if errorTracker != nil {
Info("Error tracking initialized")
}
}
// GetErrorTracker returns the current error tracking provider
func GetErrorTracker() errortracking.Provider {
return errorTracker
}
// CloseErrorTracking flushes and closes the error tracking provider
func CloseErrorTracking() error {
if errorTracker != nil {
errorTracker.Flush(5)
return errorTracker.Close()
}
return nil
}
func Info(template string, args ...interface{}) { func Info(template string, args ...interface{}) {
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf(template, args...)
@ -48,19 +84,35 @@ func Info(template string, args ...interface{}) {
} }
func Warn(template string, args ...interface{}) { func Warn(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Warnw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityWarning, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Warnw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Error(template string, args ...interface{}) { func Error(template string, args ...interface{}) {
message := fmt.Sprintf(template, args...)
if Logger == nil { if Logger == nil {
log.Printf(template, args...) log.Printf("%s", message)
return } else {
Logger.Errorw(message, "process_id", os.Getpid())
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CaptureMessage(context.Background(), message, errortracking.SeverityError, map[string]interface{}{
"process_id": os.Getpid(),
})
} }
Logger.Errorw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
func Debug(template string, args ...interface{}) { func Debug(template string, args ...interface{}) {
@ -70,3 +122,58 @@ func Debug(template string, args ...interface{}) {
} }
Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid()) Logger.Debugw(fmt.Sprintf(template, args...), "process_id", os.Getpid())
} }
// CatchPanic - Handle panic
func CatchPanicCallback(location string, cb func(err any)) {
if err := recover(); err != nil {
callstack := debug.Stack()
if Logger != nil {
Error("Panic in %s : %v", location, err)
} else {
fmt.Printf("%s:PANIC->%+v", location, err)
debug.PrintStack()
}
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), err, callstack, map[string]interface{}{
"location": location,
"process_id": os.Getpid(),
})
}
if cb != nil {
cb(err)
}
}
}
// CatchPanic - Handle panic
func CatchPanic(location string) {
CatchPanicCallback(location, nil)
}
// HandlePanic logs a panic and returns it as an error
// This should be called with the result of recover() from a deferred function
// Example usage:
//
// defer func() {
// if r := recover(); r != nil {
// err = logger.HandlePanic("MethodName", r)
// }
// }()
func HandlePanic(methodName string, r any) error {
stack := debug.Stack()
Error("Panic in %s: %v\nStack trace:\n%s", methodName, r, string(stack))
// Send to error tracker
if errorTracker != nil {
errorTracker.CapturePanic(context.Background(), r, stack, map[string]interface{}{
"method": methodName,
"process_id": os.Getpid(),
})
}
return fmt.Errorf("panic in %s: %v", methodName, r)
}

259
pkg/metrics/README.md Normal file
View File

@ -0,0 +1,259 @@
# Metrics Package
A pluggable metrics collection system with Prometheus implementation.
## Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/metrics"
// Initialize Prometheus provider
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Apply middleware to your router
router.Use(provider.Middleware)
// Expose metrics endpoint
http.Handle("/metrics", provider.Handler())
```
## Provider Interface
The package uses a provider interface, allowing you to plug in different metric systems:
```go
type Provider interface {
RecordHTTPRequest(method, path, status string, duration time.Duration)
IncRequestsInFlight()
DecRequestsInFlight()
RecordDBQuery(operation, table string, duration time.Duration, err error)
RecordCacheHit(provider string)
RecordCacheMiss(provider string)
UpdateCacheSize(provider string, size int64)
Handler() http.Handler
}
```
## Recording Metrics
### HTTP Metrics (Automatic)
When using the middleware, HTTP metrics are recorded automatically:
```go
router.Use(provider.Middleware)
```
**Collected:**
- Request duration (histogram)
- Request count by method, path, and status
- Requests in flight (gauge)
### Database Metrics
```go
start := time.Now()
rows, err := db.Query("SELECT * FROM users WHERE id = ?", userID)
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
```
### Cache Metrics
```go
// Record cache hit
metrics.GetProvider().RecordCacheHit("memory")
// Record cache miss
metrics.GetProvider().RecordCacheMiss("memory")
// Update cache size
metrics.GetProvider().UpdateCacheSize("memory", 1024)
```
## Prometheus Metrics
When using `PrometheusProvider`, the following metrics are available:
| Metric Name | Type | Labels | Description |
|-------------|------|--------|-------------|
| `http_request_duration_seconds` | Histogram | method, path, status | HTTP request duration |
| `http_requests_total` | Counter | method, path, status | Total HTTP requests |
| `http_requests_in_flight` | Gauge | - | Current in-flight requests |
| `db_query_duration_seconds` | Histogram | operation, table | Database query duration |
| `db_queries_total` | Counter | operation, table, status | Total database queries |
| `cache_hits_total` | Counter | provider | Total cache hits |
| `cache_misses_total` | Counter | provider | Total cache misses |
| `cache_size_items` | Gauge | provider | Current cache size |
## Prometheus Queries
### HTTP Request Rate
```promql
rate(http_requests_total[5m])
```
### HTTP Request Duration (95th percentile)
```promql
histogram_quantile(0.95, rate(http_request_duration_seconds_bucket[5m]))
```
### Database Query Error Rate
```promql
rate(db_queries_total{status="error"}[5m])
```
### Cache Hit Rate
```promql
rate(cache_hits_total[5m]) / (rate(cache_hits_total[5m]) + rate(cache_misses_total[5m]))
```
## No-Op Provider
If metrics are disabled:
```go
// No provider set - uses no-op provider automatically
metrics.GetProvider().RecordHTTPRequest(...) // Does nothing
```
## Custom Provider
Implement your own metrics provider:
```go
type CustomProvider struct{}
func (c *CustomProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
// Send to your metrics system
}
// Implement other Provider interface methods...
func (c *CustomProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Return your metrics format
})
}
// Use it
metrics.SetProvider(&CustomProvider{})
```
## Complete Example
```go
package main
import (
"database/sql"
"log"
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/metrics"
"github.com/gorilla/mux"
)
func main() {
// Initialize metrics
provider := metrics.NewPrometheusProvider()
metrics.SetProvider(provider)
// Create router
router := mux.NewRouter()
// Apply metrics middleware
router.Use(provider.Middleware)
// Expose metrics endpoint
router.Handle("/metrics", provider.Handler())
// Your API routes
router.HandleFunc("/api/users", getUsersHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
func getUsersHandler(w http.ResponseWriter, r *http.Request) {
// Record database query
start := time.Now()
users, err := fetchUsers()
duration := time.Since(start)
metrics.GetProvider().RecordDBQuery("SELECT", "users", duration, err)
if err != nil {
http.Error(w, "Internal Server Error", 500)
return
}
// Return users...
}
```
## Docker Compose Example
```yaml
version: '3'
services:
app:
build: .
ports:
- "8080:8080"
prometheus:
image: prom/prometheus
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
grafana:
image: grafana/grafana
ports:
- "3000:3000"
depends_on:
- prometheus
```
**prometheus.yml:**
```yaml
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'resolvespec'
static_configs:
- targets: ['app:8080']
```
## Best Practices
1. **Label Cardinality**: Keep labels low-cardinality
- ✅ Good: `method`, `status_code`
- ❌ Bad: `user_id`, `timestamp`
2. **Path Normalization**: Normalize dynamic paths
```go
// Instead of /api/users/123
// Use /api/users/:id
```
3. **Metric Naming**: Follow Prometheus conventions
- Use `_total` suffix for counters
- Use `_seconds` suffix for durations
- Use base units (seconds, not milliseconds)
4. **Performance**: Metrics collection is lock-free and highly performant
- Safe for high-throughput applications
- Minimal overhead (<1% in most cases)

86
pkg/metrics/interfaces.go Normal file
View File

@ -0,0 +1,86 @@
package metrics
import (
"net/http"
"time"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// Provider defines the interface for metric collection
type Provider interface {
// RecordHTTPRequest records metrics for an HTTP request
RecordHTTPRequest(method, path, status string, duration time.Duration)
// IncRequestsInFlight increments the in-flight requests counter
IncRequestsInFlight()
// DecRequestsInFlight decrements the in-flight requests counter
DecRequestsInFlight()
// RecordDBQuery records metrics for a database query
RecordDBQuery(operation, table string, duration time.Duration, err error)
// RecordCacheHit records a cache hit
RecordCacheHit(provider string)
// RecordCacheMiss records a cache miss
RecordCacheMiss(provider string)
// UpdateCacheSize updates the cache size metric
UpdateCacheSize(provider string, size int64)
// RecordEventPublished records an event publication
RecordEventPublished(source, eventType string)
// RecordEventProcessed records an event processing with its status
RecordEventProcessed(source, eventType, status string, duration time.Duration)
// UpdateEventQueueSize updates the event queue size metric
UpdateEventQueueSize(size int64)
// Handler returns an HTTP handler for exposing metrics (e.g., /metrics endpoint)
Handler() http.Handler
}
// globalProvider is the global metrics provider
var globalProvider Provider
// SetProvider sets the global metrics provider
func SetProvider(p Provider) {
globalProvider = p
}
// GetProvider returns the current metrics provider
func GetProvider() Provider {
if globalProvider == nil {
// Return no-op provider if none is set
return &NoOpProvider{}
}
return globalProvider
}
// NoOpProvider is a no-op implementation of Provider
type NoOpProvider struct{}
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
func (n *NoOpProvider) IncRequestsInFlight() {}
func (n *NoOpProvider) DecRequestsInFlight() {}
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
}
func (n *NoOpProvider) RecordCacheHit(provider string) {}
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
func (n *NoOpProvider) UpdateCacheSize(provider string, size int64) {}
func (n *NoOpProvider) RecordEventPublished(source, eventType string) {}
func (n *NoOpProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
}
func (n *NoOpProvider) UpdateEventQueueSize(size int64) {}
func (n *NoOpProvider) Handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
_, err := w.Write([]byte("Metrics provider not configured"))
if err != nil {
logger.Warn("Failed to write. %v", err)
}
})
}

174
pkg/metrics/prometheus.go Normal file
View File

@ -0,0 +1,174 @@
package metrics
import (
"net/http"
"strconv"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// PrometheusProvider implements the Provider interface using Prometheus
type PrometheusProvider struct {
requestDuration *prometheus.HistogramVec
requestTotal *prometheus.CounterVec
requestsInFlight prometheus.Gauge
dbQueryDuration *prometheus.HistogramVec
dbQueryTotal *prometheus.CounterVec
cacheHits *prometheus.CounterVec
cacheMisses *prometheus.CounterVec
cacheSize *prometheus.GaugeVec
}
// NewPrometheusProvider creates a new Prometheus metrics provider
func NewPrometheusProvider() *PrometheusProvider {
return &PrometheusProvider{
requestDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path", "status"},
),
requestTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
),
requestsInFlight: promauto.NewGauge(
prometheus.GaugeOpts{
Name: "http_requests_in_flight",
Help: "Current number of HTTP requests being processed",
},
),
dbQueryDuration: promauto.NewHistogramVec(
prometheus.HistogramOpts{
Name: "db_query_duration_seconds",
Help: "Database query duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"operation", "table"},
),
dbQueryTotal: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "db_queries_total",
Help: "Total number of database queries",
},
[]string{"operation", "table", "status"},
),
cacheHits: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_hits_total",
Help: "Total number of cache hits",
},
[]string{"provider"},
),
cacheMisses: promauto.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_misses_total",
Help: "Total number of cache misses",
},
[]string{"provider"},
),
cacheSize: promauto.NewGaugeVec(
prometheus.GaugeOpts{
Name: "cache_size_items",
Help: "Number of items in cache",
},
[]string{"provider"},
),
}
}
// ResponseWriter wraps http.ResponseWriter to capture status code
type ResponseWriter struct {
http.ResponseWriter
statusCode int
}
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{
ResponseWriter: w,
statusCode: http.StatusOK,
}
}
func (rw *ResponseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
// RecordHTTPRequest implements Provider interface
func (p *PrometheusProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
p.requestDuration.WithLabelValues(method, path, status).Observe(duration.Seconds())
p.requestTotal.WithLabelValues(method, path, status).Inc()
}
// IncRequestsInFlight implements Provider interface
func (p *PrometheusProvider) IncRequestsInFlight() {
p.requestsInFlight.Inc()
}
// DecRequestsInFlight implements Provider interface
func (p *PrometheusProvider) DecRequestsInFlight() {
p.requestsInFlight.Dec()
}
// RecordDBQuery implements Provider interface
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
status := "success"
if err != nil {
status = "error"
}
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
}
// RecordCacheHit implements Provider interface
func (p *PrometheusProvider) RecordCacheHit(provider string) {
p.cacheHits.WithLabelValues(provider).Inc()
}
// RecordCacheMiss implements Provider interface
func (p *PrometheusProvider) RecordCacheMiss(provider string) {
p.cacheMisses.WithLabelValues(provider).Inc()
}
// UpdateCacheSize implements Provider interface
func (p *PrometheusProvider) UpdateCacheSize(provider string, size int64) {
p.cacheSize.WithLabelValues(provider).Set(float64(size))
}
// Handler implements Provider interface
func (p *PrometheusProvider) Handler() http.Handler {
return promhttp.Handler()
}
// Middleware returns an HTTP middleware that collects metrics
func (p *PrometheusProvider) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
// Increment in-flight requests
p.IncRequestsInFlight()
defer p.DecRequestsInFlight()
// Wrap response writer to capture status code
rw := NewResponseWriter(w)
// Call next handler
next.ServeHTTP(rw, r)
// Record metrics
duration := time.Since(start)
status := strconv.Itoa(rw.statusCode)
p.RecordHTTPRequest(r.Method, r.URL.Path, status, duration)
})
}

806
pkg/middleware/README.md Normal file
View File

@ -0,0 +1,806 @@
# Middleware Package
HTTP middleware utilities for security and performance.
## Table of Contents
1. [Rate Limiting](#rate-limiting)
2. [Request Size Limits](#request-size-limits)
3. [Input Sanitization](#input-sanitization)
---
## Rate Limiting
Production-grade rate limiting using token bucket algorithm.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Create rate limiter: 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
// Apply to all routes
router.Use(rateLimiter.Middleware)
```
### Basic Usage
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Rate limit: 10 requests per second, burst of 5
rateLimiter := middleware.NewRateLimiter(10, 5)
router.Use(rateLimiter.Middleware)
router.HandleFunc("/api/data", dataHandler)
log.Fatal(http.ListenAndServe(":8080", router))
}
```
### Custom Key Extraction
By default, rate limiting is per IP address. Customize the key:
```go
// Rate limit by User ID from header
keyFunc := func(r *http.Request) string {
userID := r.Header.Get("X-User-ID")
if userID == "" {
return r.RemoteAddr // Fallback to IP
}
return "user:" + userID
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
### Advanced Key Functions
**By API Key:**
```go
keyFunc := func(r *http.Request) string {
apiKey := r.Header.Get("X-API-Key")
if apiKey == "" {
return r.RemoteAddr
}
return "api:" + apiKey
}
```
**By Authenticated User:**
```go
keyFunc := func(r *http.Request) string {
// Extract from JWT or session
user := getUserFromContext(r.Context())
if user != nil {
return "user:" + user.ID
}
return r.RemoteAddr
}
```
**By Path + User:**
```go
keyFunc := func(r *http.Request) string {
user := getUserFromContext(r.Context())
if user != nil {
return fmt.Sprintf("user:%s:path:%s", user.ID, r.URL.Path)
}
return r.URL.Path + ":" + r.RemoteAddr
}
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// Public endpoints: 10 rps
publicLimiter := middleware.NewRateLimiter(10, 5)
// API endpoints: 100 rps
apiLimiter := middleware.NewRateLimiter(100, 20)
// Admin endpoints: 1000 rps
adminLimiter := middleware.NewRateLimiter(1000, 50)
// Apply different limiters to subrouters
publicRouter := router.PathPrefix("/public").Subrouter()
publicRouter.Use(publicLimiter.Middleware)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
adminRouter := router.PathPrefix("/admin").Subrouter()
adminRouter.Use(adminLimiter.Middleware)
}
```
### Rate Limit Response
When rate limited, clients receive:
```http
HTTP/1.1 429 Too Many Requests
Content-Type: text/plain
{"error":"rate_limit_exceeded","message":"Too many requests"}
```
### Configuration Examples
**Tight Rate Limit (Anti-abuse):**
```go
// 1 request per second, burst of 3
rateLimiter := middleware.NewRateLimiter(1, 3)
```
**Moderate Rate Limit (Standard API):**
```go
// 100 requests per second, burst of 20
rateLimiter := middleware.NewRateLimiter(100, 20)
```
**Generous Rate Limit (Internal Services):**
```go
// 1000 requests per second, burst of 100
rateLimiter := middleware.NewRateLimiter(1000, 100)
```
**Time-based Limits:**
```go
// 60 requests per minute = 1 request per second
rateLimiter := middleware.NewRateLimiter(1, 10)
// 1000 requests per hour ≈ 0.28 requests per second
rateLimiter := middleware.NewRateLimiter(0.28, 50)
```
### Understanding Burst
The burst parameter allows short bursts above the rate:
```go
// Rate: 10 rps, Burst: 5
// Allows up to 5 requests immediately, then 10/second
rateLimiter := middleware.NewRateLimiter(10, 5)
```
**Bucket fills at rate:** 10 tokens/second
**Bucket capacity:** 5 tokens
**Request consumes:** 1 token
**Example traffic pattern:**
- T=0s: 5 requests → ✅ All allowed (burst)
- T=0.1s: 1 request → ❌ Denied (bucket empty)
- T=0.5s: 1 request → ✅ Allowed (bucket refilled 0.5 tokens)
- T=1s: 1 request → ✅ Allowed (bucket has ~1 token)
### Cleanup Behavior
The rate limiter automatically cleans up inactive limiters every 5 minutes to prevent memory leaks.
### Performance Characteristics
- **Memory**: ~100 bytes per active limiter
- **Throughput**: >1M requests/second
- **Latency**: <1μs per request
- **Concurrency**: Lock-free for rate checks
### Production Deployment
**With Reverse Proxy:**
```go
// Use X-Forwarded-For or X-Real-IP
keyFunc := func(r *http.Request) string {
// Check proxy headers first
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return strings.Split(ip, ",")[0]
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return ip
}
return r.RemoteAddr
}
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
```
**Environment-based Configuration:**
```go
import "os"
func getRateLimiter() *middleware.RateLimiter {
rps := getEnvFloat("RATE_LIMIT_RPS", 100)
burst := getEnvInt("RATE_LIMIT_BURST", 20)
return middleware.NewRateLimiter(rps, burst)
}
```
### Testing Rate Limits
```bash
# Send 10 requests rapidly
for i in {1..10}; do
curl -w "Status: %{http_code}\n" http://localhost:8080/api/data
done
```
**Expected output:**
```
Status: 200 # Request 1-5 (within burst)
Status: 200
Status: 200
Status: 200
Status: 200
Status: 429 # Request 6-10 (rate limited)
Status: 429
Status: 429
Status: 429
Status: 429
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"os"
"strconv"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
// Configuration from environment
rps, _ := strconv.ParseFloat(os.Getenv("RATE_LIMIT_RPS"), 64)
if rps == 0 {
rps = 100 // Default
}
burst, _ := strconv.Atoi(os.Getenv("RATE_LIMIT_BURST"))
if burst == 0 {
burst = 20 // Default
}
// Create rate limiter
rateLimiter := middleware.NewRateLimiter(rps, burst)
// Custom key extraction
keyFunc := func(r *http.Request) string {
// Try API key first
if apiKey := r.Header.Get("X-API-Key"); apiKey != "" {
return "api:" + apiKey
}
// Try authenticated user
if userID := r.Header.Get("X-User-ID"); userID != "" {
return "user:" + userID
}
// Fall back to IP
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
return ip
}
return r.RemoteAddr
}
// Create router
router := mux.NewRouter()
// Apply rate limiting
router.Use(rateLimiter.MiddlewareWithKeyFunc(keyFunc))
// Routes
router.HandleFunc("/api/data", dataHandler)
router.HandleFunc("/health", healthHandler)
log.Printf("Starting server with rate limit: %.1f rps, burst: %d", rps, burst)
log.Fatal(http.ListenAndServe(":8080", router))
}
func dataHandler(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]string{
"message": "Data endpoint",
})
}
func healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("OK"))
}
```
## Best Practices
1. **Set Appropriate Limits**: Consider your backend capacity
- Database: Can it handle X queries/second?
- External APIs: What are their rate limits?
- Server resources: CPU, memory, connections
2. **Use Burst Wisely**: Allow legitimate traffic spikes
- Too low: Reject valid bursts
- Too high: Allow abuse
3. **Monitor Rate Limits**: Track how often limits are hit
```go
// Log rate limit events
if rateLimited {
log.Printf("Rate limited: %s", clientKey)
}
```
4. **Provide Feedback**: Include rate limit headers (future enhancement)
```http
X-RateLimit-Limit: 100
X-RateLimit-Remaining: 95
X-RateLimit-Reset: 1640000000
```
5. **Tiered Limits**: Different limits for different user tiers
```go
func getRateLimiter(userTier string) *middleware.RateLimiter {
switch userTier {
case "premium":
return middleware.NewRateLimiter(1000, 100)
case "standard":
return middleware.NewRateLimiter(100, 20)
default:
return middleware.NewRateLimiter(10, 5)
}
}
```
---
## Request Size Limits
Protect against oversized request bodies with configurable size limits.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default: 10MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(0)
router.Use(sizeLimiter.Middleware)
```
### Custom Size Limit
```go
// 5MB limit
sizeLimiter := middleware.NewRequestSizeLimiter(5 * 1024 * 1024)
router.Use(sizeLimiter.Middleware)
// Or use constants
sizeLimiter := middleware.NewRequestSizeLimiter(middleware.Size5MB)
```
### Available Size Constants
```go
middleware.Size1MB // 1 MB
middleware.Size5MB // 5 MB
middleware.Size10MB // 10 MB (default)
middleware.Size50MB // 50 MB
middleware.Size100MB // 100 MB
```
### Different Limits Per Route
```go
func main() {
router := mux.NewRouter()
// File upload endpoint: 50MB
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
uploadRouter := router.PathPrefix("/upload").Subrouter()
uploadRouter.Use(uploadLimiter.Middleware)
// API endpoints: 1MB
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
apiRouter := router.PathPrefix("/api").Subrouter()
apiRouter.Use(apiLimiter.Middleware)
}
```
### Dynamic Size Limits
```go
// Custom size based on request
sizeFunc := func(r *http.Request) int64 {
// Premium users get 50MB
if isPremiumUser(r) {
return middleware.Size50MB
}
// Free users get 5MB
return middleware.Size5MB
}
router.Use(sizeLimiter.MiddlewareWithCustomSize(sizeFunc))
```
**By Content-Type:**
```go
sizeFunc := func(r *http.Request) int64 {
contentType := r.Header.Get("Content-Type")
switch {
case strings.Contains(contentType, "multipart/form-data"):
return middleware.Size50MB // File uploads
case strings.Contains(contentType, "application/json"):
return middleware.Size1MB // JSON APIs
default:
return middleware.Size10MB // Default
}
}
```
### Error Response
When size limit exceeded:
```http
HTTP/1.1 413 Request Entity Too Large
X-Max-Request-Size: 10485760
http: request body too large
```
### Complete Example
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// API routes: 1MB limit
api := router.PathPrefix("/api").Subrouter()
apiLimiter := middleware.NewRequestSizeLimiter(middleware.Size1MB)
api.Use(apiLimiter.Middleware)
api.HandleFunc("/users", createUserHandler).Methods("POST")
// Upload routes: 50MB limit
upload := router.PathPrefix("/upload").Subrouter()
uploadLimiter := middleware.NewRequestSizeLimiter(middleware.Size50MB)
upload.Use(uploadLimiter.Middleware)
upload.HandleFunc("/file", uploadFileHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
```
---
## Input Sanitization
Protect against XSS, injection attacks, and malicious input.
### Quick Start
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
// Default sanitizer (safe defaults)
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
```
### Sanitizer Types
**Default Sanitizer (Recommended):**
```go
sanitizer := middleware.DefaultSanitizer()
// ✓ Escapes HTML entities
// ✓ Removes null bytes
// ✓ Removes control characters
// ✓ Blocks XSS patterns (script tags, event handlers)
// ✗ Does not strip HTML (allows legitimate content)
```
**Strict Sanitizer:**
```go
sanitizer := middleware.StrictSanitizer()
// ✓ All default features
// ✓ Strips ALL HTML tags
// ✓ Max string length: 10,000 chars
```
### Custom Configuration
```go
sanitizer := &middleware.Sanitizer{
StripHTML: true, // Remove HTML tags
EscapeHTML: false, // Don't escape (already stripped)
RemoveNullBytes: true, // Remove \x00
RemoveControlChars: true, // Remove dangerous control chars
MaxStringLength: 5000, // Limit to 5000 chars
// Block patterns (regex)
BlockPatterns: []*regexp.Regexp{
regexp.MustCompile(`(?i)<script`),
regexp.MustCompile(`(?i)javascript:`),
},
// Custom sanitization function
CustomSanitizer: func(s string) string {
// Your custom logic
return strings.ToLower(s)
},
}
router.Use(sanitizer.Middleware)
```
### What Gets Sanitized
**Automatic (via middleware):**
- Query parameters
- Headers (User-Agent, Referer, X-Forwarded-For, X-Real-IP)
**Manual (in your handler):**
- Request body (JSON, form data)
- Database queries
- File names
### Manual Sanitization
**String Values:**
```go
sanitizer := middleware.DefaultSanitizer()
// Sanitize user input
username := sanitizer.Sanitize(r.FormValue("username"))
email := sanitizer.Sanitize(r.FormValue("email"))
```
**Map/JSON Data:**
```go
var data map[string]interface{}
json.Unmarshal(body, &data)
// Sanitize all string values recursively
sanitizedData := sanitizer.SanitizeMap(data)
```
**Nested Structures:**
```go
type User struct {
Name string
Email string
Bio string
Profile map[string]interface{}
}
// After unmarshaling
user.Name = sanitizer.Sanitize(user.Name)
user.Email = sanitizer.Sanitize(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
user.Profile = sanitizer.SanitizeMap(user.Profile)
```
### Specialized Sanitizers
**Filenames:**
```go
import "github.com/bitechdev/ResolveSpec/pkg/middleware"
filename := middleware.SanitizeFilename(uploadedFilename)
// Removes: .., /, \, null bytes
// Limits: 255 characters
```
**Emails:**
```go
email := middleware.SanitizeEmail(" USER@EXAMPLE.COM ")
// Result: "user@example.com"
// Trims, lowercases, removes null bytes
```
**URLs:**
```go
url := middleware.SanitizeURL(userInput)
// Blocks: javascript:, data: protocols
// Removes: null bytes
```
### Blocked Patterns (Default)
The default sanitizer blocks:
1. **Script tags**: `<script>...</script>`
2. **JavaScript protocol**: `javascript:alert(1)`
3. **Event handlers**: `onclick="..."`, `onerror="..."`
4. **Iframes**: `<iframe src="...">`
5. **Objects**: `<object data="...">`
6. **Embeds**: `<embed src="...">`
### Security Best Practices
**1. Layer Defense:**
```go
// Layer 1: Middleware (query params, headers)
router.Use(sanitizer.Middleware)
// Layer 2: Input validation (in handler)
func createUserHandler(w http.ResponseWriter, r *http.Request) {
var user User
json.NewDecoder(r.Body).Decode(&user)
// Sanitize
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
// Validate
if !isValidEmail(user.Email) {
http.Error(w, "Invalid email", 400)
return
}
// Use parameterized queries (prevents SQL injection)
db.Exec("INSERT INTO users (name, email) VALUES (?, ?)",
user.Name, user.Email)
}
```
**2. Context-Aware Sanitization:**
```go
// HTML content (user posts, comments)
sanitizer := middleware.StrictSanitizer()
post.Content = sanitizer.Sanitize(post.Content)
// Structured data (JSON API)
sanitizer := middleware.DefaultSanitizer()
data = sanitizer.SanitizeMap(jsonData)
// Search queries (preserve special chars)
query = middleware.SanitizeFilename(searchTerm) // Light sanitization
```
**3. Output Encoding:**
```go
// When rendering HTML
import "html/template"
tmpl := template.Must(template.New("page").Parse(`
<h1>{{.Title}}</h1>
<p>{{.Content}}</p>
`))
// template.HTML automatically escapes
tmpl.Execute(w, data)
```
### Complete Example
```go
package main
import (
"encoding/json"
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/middleware"
"github.com/gorilla/mux"
)
func main() {
router := mux.NewRouter()
// Apply sanitization middleware
sanitizer := middleware.DefaultSanitizer()
router.Use(sanitizer.Middleware)
router.HandleFunc("/api/users", createUserHandler).Methods("POST")
log.Fatal(http.ListenAndServe(":8080", router))
}
func createUserHandler(w http.ResponseWriter, r *http.Request) {
sanitizer := middleware.DefaultSanitizer()
var user struct {
Name string `json:"name"`
Email string `json:"email"`
Bio string `json:"bio"`
}
if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
http.Error(w, "Invalid JSON", 400)
return
}
// Sanitize inputs
user.Name = sanitizer.Sanitize(user.Name)
user.Email = middleware.SanitizeEmail(user.Email)
user.Bio = sanitizer.Sanitize(user.Bio)
// Validate
if len(user.Name) == 0 || len(user.Email) == 0 {
http.Error(w, "Name and email required", 400)
return
}
// Save to database (use parameterized queries!)
// db.Exec("INSERT INTO users (name, email, bio) VALUES (?, ?, ?)",
// user.Name, user.Email, user.Bio)
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{
"status": "created",
})
}
```
### Testing Sanitization
```bash
# Test XSS prevention
curl -X POST http://localhost:8080/api/users \
-H "Content-Type: application/json" \
-d '{
"name": "<script>alert(1)</script>John",
"email": "test@example.com",
"bio": "My bio with <iframe src=\"evil.com\"></iframe>"
}'
# Script tags and iframes should be removed
```
### Performance
- **Overhead**: <1ms per request for typical payloads
- **Regex compilation**: Done once at initialization
- **Safe for production**: Minimal performance impact

Some files were not shown because too many files have changed in this diff Show More