Compare commits

...

25 Commits

Author SHA1 Message Date
c2e2c9b873 feat(transport): add streamable HTTP transport for MCP 2026-04-07 19:52:38 +02:00
4adf94fe37 feat(go.mod): add mcp-go dependency for enhanced functionality 2026-04-07 19:09:51 +02:00
Hein
405a04a192 feat(security): integrate security hooks for access control
Some checks failed
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m6s
Tests / Unit Tests (push) Successful in -30m22s
Tests / Integration Tests (push) Failing after -30m41s
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m3s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m36s
Build , Vet Test, and Lint / Build (push) Successful in -29m58s
* Add security hooks for per-entity operation rules and row/column-level security.
* Implement annotation tool for storing and retrieving freeform annotations.
* Enhance handler to support model registration with access rules.
2026-04-07 15:53:12 +02:00
Hein
c1b16d363a feat(db): add DB method to sqlConnection and mongoConnection
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m22s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m59s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m11s
Build , Vet Test, and Lint / Build (push) Successful in -30m12s
Tests / Unit Tests (push) Successful in -30m49s
Tests / Integration Tests (push) Failing after -30m59s
2026-04-01 15:34:09 +02:00
Hein
568df8c6d6 feat(security): add configurable SQL procedure names
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -25m9s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -24m29s
Build , Vet Test, and Lint / Build (push) Successful in -30m5s
Build , Vet Test, and Lint / Lint Code (push) Failing after -28m58s
Tests / Integration Tests (push) Failing after -30m26s
Tests / Unit Tests (push) Successful in -28m7s
* Introduce SQLNames struct to define stored procedure names.
* Update DatabaseAuthenticator, JWTAuthenticator, and other providers to use SQLNames for procedure calls.
* Remove hardcoded procedure names for better flexibility and customization.
* Implement validation for SQL names to ensure they are valid identifiers.
* Add tests for SQLNames functionality and merging behavior.
2026-03-31 14:25:59 +02:00
Hein
aa362c77da fix(cursor): trim parentheses from sort column names 2026-03-27 15:07:10 +02:00
Hein
1641eaf278 feat(resolvemcp): enhance handler with configuration support
* Introduce Config struct for BaseURL and BasePath settings
* Update handler creation functions to accept configuration
* Modify SSEServer to use dynamic base URL detection
* Adjust route setup functions to utilize BasePath from config
2026-03-27 13:56:03 +02:00
Hein
200a03c225 feat(resolvemcp): add SSE server and bunrouter setup functions
* Introduce SSEServer method for creating an SSE server bound to the handler.
* Add SetupBunRouterRoutes function to mount MCP HTTP/SSE endpoints on bunrouter.
* Update README with usage examples for new features.
2026-03-27 13:28:03 +02:00
Hein
7ef9cf39d3 style(tools): simplify string formatting in descriptions 2026-03-27 13:10:50 +02:00
Hein
7f6410f665 feat(resolvemcp): add support for join-column sorting in cursor pagination
* Enhance getCursorFilter to accept join clauses for sorting
* Update resolveColumn to handle joined columns
* Modify tests to validate new join functionality
2026-03-27 13:10:42 +02:00
Hein
835bbb0727 style(hooks): reorder fields in HookContext for consistency 2026-03-27 12:57:30 +02:00
Hein
047a1cc187 feat(resolvemcp): add hook system for model operations
* Implement hooks for CRUD operations: before/after handle, read, create, update, delete.
* Introduce HookContext and HookRegistry for managing hooks.
* Allow registration and execution of multiple hooks per operation.

feat(resolvemcp): implement MCP tools for CRUD operations
* Register tools for reading, creating, updating, and deleting records.
* Define tool arguments and handle requests with appropriate responses.
* Support for resource registration with metadata.

fix(restheadspec): enhance cursor handling for joins
* Improve cursor filter generation to support lateral joins.
* Update join alias extraction to handle lateral joins correctly.
* Ensure cursor filters do not contain empty comparisons.

test(restheadspec): add tests for cursor filters and join alias extraction
* Create tests for lateral join scenarios in cursor filter generation.
* Validate join alias extraction for various join types, including lateral joins.
2026-03-27 12:57:08 +02:00
Hein
7a498edab7 fix(headers): enhance relation name resolution logic
* Allow resolution for both regular headers and X-Files.
* Introduce join-key-aware resolution for disambiguation.
* Add new function to handle multiple fields pointing to the same type.
2026-03-25 12:09:03 +02:00
Hein
f10bb0827e fix(sql_helpers): ensure case-insensitive matching for allowed prefixes 2026-03-25 10:57:42 +02:00
Hein
22a4ab345a feat(security): add session cookie management functions
* Introduce SessionCookieOptions for configurable session cookies
* Implement SetSessionCookie, GetSessionCookie, and ClearSessionCookie functions
* Enhance cookie handling in DatabaseAuthenticator
2026-03-24 17:11:53 +02:00
Hein
e289c2ed8f fix(handler): restore JoinAliases for proper WHERE sanitization 2026-03-24 12:00:02 +02:00
Hein
0d50bcfee6 fix(provider): enhance file opening logic with alternate path. Handling broken cases to be compatible with Bitech clients
* Implemented alternate path handling for file retrieval
* Improved error messaging for file not found scenarios
2026-03-24 09:02:17 +02:00
4df626ea71 chore(license): update project notice and clarify licensing terms 2026-03-23 20:32:09 +02:00
Hein
7dd630dec2 fix(handler): set default sort to primary key if none provided
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -26m15s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -26m11s
Build , Vet Test, and Lint / Lint Code (push) Failing after -30m52s
Build , Vet Test, and Lint / Build (push) Successful in -30m44s
Tests / Integration Tests (push) Failing after -31m5s
Tests / Unit Tests (push) Successful in -29m6s
2026-03-11 14:37:04 +02:00
Hein
613bf22cbd fix(cursor): use full schema-qualified table name in filters 2026-03-11 14:25:44 +02:00
d1ae4fe64e refactor(handler): unify filter operator handling for consistency
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m26s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m58s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m48s
Build , Vet Test, and Lint / Build (push) Successful in -30m4s
Tests / Integration Tests (push) Failing after -30m39s
Tests / Unit Tests (push) Successful in -30m29s
2026-03-01 13:21:38 +02:00
254102bfac refactor(auth): simplify handler type assertions for middleware
Some checks failed
Build , Vet Test, and Lint / Run Vet Tests (1.24.x) (push) Successful in -30m2s
Build , Vet Test, and Lint / Run Vet Tests (1.23.x) (push) Successful in -29m31s
Build , Vet Test, and Lint / Lint Code (push) Successful in -29m19s
Build , Vet Test, and Lint / Build (push) Successful in -29m42s
Tests / Integration Tests (push) Failing after -30m35s
Tests / Unit Tests (push) Successful in -30m17s
2026-03-01 12:08:36 +02:00
6c27419dbc refactor(auth): enhance request handling with middleware-enriched context 2026-03-01 12:06:43 +02:00
377336caf4 feat(sql): implement IN condition handling with parameterized queries 2026-03-01 09:52:32 +02:00
79720d5421 feat(security): add BeforeHandle hook for auth checks after model resolution
- Implement BeforeHandle hook to enforce authentication based on model rules.
- Integrate with existing security mechanisms to allow or deny access.
- Update documentation to reflect new hook and its usage.
2026-03-01 09:15:30 +02:00
56 changed files with 4393 additions and 295 deletions

27
LICENSE
View File

@@ -1,3 +1,18 @@
Project Notice
This project was independently developed.
The contents of this repository were prepared and published outside any time
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
or rely upon any proprietary or confidential information, trade secrets,
protected designs, or other intellectual property of Bitech Systems CC.
No portion of this repository reproduces any Bitech Systems CC-specific
implementation, design asset, confidential workflow, or non-public technical material.
This notice is provided for clarification only and does not modify the terms of
the Apache License, Version 2.0.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -32,15 +47,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
@@ -56,7 +71,7 @@ END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
Copyright 2025 wdevs

View File

@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
3. **FuncSpec** - Header-based API to map and call API's to sql functions
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
* [Quick Start](#quick-start)
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
* [Architecture](#architecture)
* [API Structure](#api-structure)
* [RestHeadSpec Overview](#restheadspec-header-based-api)
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
* **🆕 Backward Compatible**: Existing code works without changes
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
### ResolveMCP (v3.2+)
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
### RestHeadSpec (v2.1+)
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
### ResolveMCP (MCP Server)
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
```go
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
// Create handler
handler := resolvemcp.NewHandlerWithGORM(db)
// Register models — must be done BEFORE Build()
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "posts", &Post{})
// Finalize: registers MCP tools and resources
handler.Build()
// Mount SSE transport on your existing router
router := mux.NewRouter()
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
// MCP clients connect to:
// SSE stream: GET http://localhost:8080/mcp/sse
// Messages: POST http://localhost:8080/mcp/message
//
// Auto-registered tools per model:
// read_public_users — filter, sort, paginate, preload
// create_public_users — insert a new record
// update_public_users — update a record by ID
// delete_public_users — delete a record by ID
```
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
## Architecture
### Two Complementary APIs
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
#### ResolveMCP - MCP Server
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
**Key Features**:
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
- Same Before/After lifecycle hooks as ResolveSpec
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
#### FuncSpec - Function-Based SQL API
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
@@ -529,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
## What's New
### v3.1 (Latest - February 2026)
### v3.2 (Latest - March 2026)
**ResolveMCP - Model Context Protocol Server (🆕)**:
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
### v3.1 (February 2026)
**SQLite Schema Translation (🆕)**:

5
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.8.0
github.com/klauspost/compress v1.18.2
github.com/mark3labs/mcp-go v0.46.0
github.com/mattn/go-sqlite3 v1.14.33
github.com/microsoft/go-mssqldb v1.9.5
github.com/mochi-mqtt/server/v2 v2.7.9
@@ -40,6 +41,7 @@ require (
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.1
golang.org/x/crypto v0.46.0
golang.org/x/oauth2 v0.34.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.6.0
gorm.io/driver/sqlite v1.6.0
@@ -78,6 +80,7 @@ require (
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
github.com/golang-sql/sqlexp v0.1.0 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/jsonschema-go v0.4.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
@@ -131,6 +134,7 @@ require (
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
github.com/xdg-go/scram v1.2.0 // indirect
github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
@@ -143,7 +147,6 @@ require (
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
golang.org/x/mod v0.31.0 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect

6
go.sum
View File

@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

View File

@@ -2,6 +2,7 @@ package common
import (
"fmt"
"reflect"
"regexp"
"strings"
@@ -167,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
}
// Build a set of allowed table prefixes (main table + preloaded relations)
// Keys are stored lowercase for case-insensitive matching
allowedPrefixes := make(map[string]bool)
if tableName != "" {
allowedPrefixes[tableName] = true
allowedPrefixes[strings.ToLower(tableName)] = true
}
// Add preload relation names as allowed prefixes
if len(options) > 0 && options[0] != nil {
for pi := range options[0].Preload {
if options[0].Preload[pi].Relation != "" {
allowedPrefixes[options[0].Preload[pi].Relation] = true
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
}
}
@@ -184,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
// Add join aliases as allowed prefixes
for _, alias := range options[0].JoinAliases {
if alias != "" {
allowedPrefixes[alias] = true
allowedPrefixes[strings.ToLower(alias)] = true
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
}
}
@@ -216,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
currentPrefix, columnName := extractTableAndColumn(condToCheck)
if currentPrefix != "" && columnName != "" {
// Check if the prefix is allowed (main table or preload relation)
if !allowedPrefixes[currentPrefix] {
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
if validColumns == nil || isValidColumn(columnName, validColumns) {
// Replace the incorrect prefix with the correct main table name
@@ -925,3 +927,36 @@ func extractLeftSideOfComparison(cond string) string {
return ""
}
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
// Returns a single-element slice if the value is not a slice type.
func FilterValueToSlice(v interface{}) []interface{} {
if v == nil {
return nil
}
rv := reflect.ValueOf(v)
if rv.Kind() == reflect.Slice {
result := make([]interface{}, rv.Len())
for i := 0; i < rv.Len(); i++ {
result[i] = rv.Index(i).Interface()
}
return result
}
return []interface{}{v}
}
// BuildInCondition builds a parameterized IN condition from a filter value.
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
// Returns ("", nil) if the value is empty or not a slice.
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
values := FilterValueToSlice(v)
if len(values) == 0 {
return "", nil
}
placeholders := make([]string, len(values))
for i := range values {
placeholders[i] = "?"
}
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
}

View File

@@ -26,6 +26,7 @@ type Connection interface {
Bun() (*bun.DB, error)
GORM() (*gorm.DB, error)
Native() (*sql.DB, error)
DB() (*sql.DB, error)
// Common Database interface (for SQL databases)
Database() (common.Database, error)
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
return c.nativeDB, nil
}
// DB returns the underlying *sql.DB connection
func (c *sqlConnection) DB() (*sql.DB, error) {
return c.Native()
}
// Bun returns a Bun ORM instance wrapping the native connection
func (c *sqlConnection) Bun() (*bun.DB, error) {
if c == nil {
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// DB returns an error for MongoDB connections
func (c *mongoConnection) DB() (*sql.DB, error) {
return nil, ErrNotSQLDatabase
}
// Database returns an error for MongoDB connections
func (c *mongoConnection) Database() (common.Database, error) {
return nil, ErrNotSQLDatabase

View File

@@ -2,14 +2,38 @@ package funcspec
import (
"context"
"fmt"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers security hooks for funcspec handlers
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
// We provide audit logging for data access tracking
// We provide auth enforcement and audit logging for data access tracking
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeQueryList - Auth check before list query execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
hookCtx.Abort = true
hookCtx.AbortMessage = "authentication required"
hookCtx.AbortCode = http.StatusUnauthorized
return fmt.Errorf("authentication required")
}
return nil
})
// Hook 0: BeforeQuery - Auth check before single query execution
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
hookCtx.Abort = true
hookCtx.AbortMessage = "authentication required"
hookCtx.AbortCode = http.StatusUnauthorized
return fmt.Errorf("authentication required")
}
return nil
})
// Hook 1: BeforeQueryList - Audit logging before query list execution
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
secCtx := newFuncSpecSecurityContext(hookCtx)

View File

@@ -10,6 +10,8 @@ import (
type ModelRules struct {
CanPublicRead bool // Whether the model can be read (GET operations)
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanPublicCreate bool // Whether the model can be created (POST operations)
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
CanRead bool // Whether the model can be read (GET operations)
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
CanCreate bool // Whether the model can be created (POST operations)
@@ -26,6 +28,8 @@ func DefaultModelRules() ModelRules {
CanDelete: true,
CanPublicRead: false,
CanPublicUpdate: false,
CanPublicCreate: false,
CanPublicDelete: false,
SecurityDisabled: false,
}
}

View File

@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
- **Database Agnostic**: GORM and Bun ORM support
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
- **Thread-safe**: Proper concurrency handling throughout
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
## Lifecycle Hooks
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
### Hook Types
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
- `BeforeRead` / `AfterRead` - Read operations
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
### Security Hooks (Recommended)
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
mqttspec.RegisterSecurityHooks(handler, securityList)
// Registers BeforeHandle (model auth), BeforeRead (load rules),
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
```
### Authentication Example (JWT)
```go
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
| **Message Protocol** | Same JSON structure | Same JSON structure |
| **Hooks** | Same 12 hooks | Same 12 hooks |
| **Hooks** | Same 13 hooks | Same 13 hooks |
| **CRUD Operations** | Identical | Identical |
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |

View File

@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
},
}
// Execute BeforeHandle hook - auth check fires here, after model resolution
hookCtx.Operation = string(msg.Operation)
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
if hookCtx.Abort {
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
}
return
}
// Route to operation handler
switch msg.Operation {
case OperationRead:

View File

@@ -20,8 +20,11 @@ type (
HookRegistry = websocketspec.HookRegistry
)
// Hook type constants - all 12 lifecycle hooks
// Hook type constants - all lifecycle hooks
const (
// BeforeHandle fires after model resolution, before operation dispatch
BeforeHandle = websocketspec.BeforeHandle
// CRUD operation hooks
BeforeRead = websocketspec.BeforeRead
AfterRead = websocketspec.AfterRead

View File

@@ -0,0 +1,108 @@
package mqttspec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LoadSecurityRules(secCtx, securityList)
})
// Hook 2: AfterRead - Apply column-level security (masking)
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.ApplyColumnSecurity(secCtx, securityList)
})
// Hook 3 (Optional): Audit logging
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.LogDataAccess(secCtx)
})
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelUpdateAllowed(secCtx)
})
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)
return security.CheckModelDeleteAllowed(secCtx)
})
logger.Info("Security hooks registered for mqttspec handler")
}
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
// GetQuery retrieves a stored query from hook metadata
func (s *securityContext) GetQuery() interface{} {
if s.ctx.Metadata == nil {
return nil
}
return s.ctx.Metadata["query"]
}
// SetQuery stores the query in hook metadata
func (s *securityContext) SetQuery(query interface{}) {
if s.ctx.Metadata == nil {
s.ctx.Metadata = make(map[string]interface{})
}
s.ctx.Metadata["query"] = query
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

542
pkg/resolvemcp/README.md Normal file
View File

@@ -0,0 +1,542 @@
# resolvemcp
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
## Quick Start
```go
import (
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
"github.com/gorilla/mux"
)
// 1. Create a handler
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
BaseURL: "http://localhost:8080",
})
// 2. Register models
handler.RegisterModel("public", "users", &User{})
handler.RegisterModel("public", "orders", &Order{})
// 3. Mount routes
r := mux.NewRouter()
resolvemcp.SetupMuxRoutes(r, handler)
```
---
## Config
```go
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
// If empty, it is detected from each incoming request using the Host header and
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
BaseURL string
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
// Required.
BasePath string
}
```
## Handler Creation
| Function | Description |
|---|---|
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
---
## Registering Models
```go
handler.RegisterModel(schema, entity string, model interface{}) error
```
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
- `entity` — table/entity name (e.g. `"users"`).
- `model` — a pointer to a struct (e.g. `&User{}`).
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
---
## HTTP Transports
`Config.BasePath` is required and used for all route registration.
`Config.BaseURL` is optional — when empty it is detected from each request.
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
---
### SSE Transport
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
#### Gorilla Mux
```go
resolvemcp.SetupMuxRoutes(r, handler)
```
| Route | Method | Description |
|---|---|---|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
#### bunrouter
```go
resolvemcp.SetupBunRouterRoutes(router, handler)
```
#### Gin / net/http / Echo
```go
sse := handler.SSEServer()
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
http.Handle("/mcp/", sse) // net/http
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
```
---
### Streamable HTTP Transport
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
#### Gorilla Mux
```go
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
```
Mounts the handler at `{BasePath}` (all methods).
#### bunrouter
```go
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
```
Registers GET, POST, DELETE on `{BasePath}`.
#### Gin / net/http / Echo
```go
h := handler.StreamableHTTPServer()
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
engine.Any("/mcp", gin.WrapH(h)) // Gin
http.Handle("/mcp", h) // net/http
e.Any("/mcp", echo.WrapHandler(h)) // Echo
```
---
### Authentication
Add middleware before the MCP routes. The handler itself has no auth layer.
---
## Security
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
### Wiring security hooks
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
securityList := security.NewSecurityList(mySecurityProvider)
resolvemcp.RegisterSecurityHooks(handler, securityList)
```
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
| Hook | Effect |
|---|---|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
### Per-entity operation rules
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
```go
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
// Read-only entity
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
CanRead: true,
CanCreate: false,
CanUpdate: false,
CanDelete: false,
})
// Public read, authenticated write
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
CanPublicRead: true,
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
To update rules for an already-registered model:
```go
handler.SetModelRules("public", "users", modelregistry.ModelRules{
CanRead: true,
CanCreate: true,
CanUpdate: true,
CanDelete: false,
})
```
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
### ModelRules fields
| Field | Default | Description |
|---|---|---|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
| `CanRead` | `true` | Allow authenticated reads |
| `CanCreate` | `true` | Allow authenticated creates |
| `CanUpdate` | `true` | Allow authenticated updates |
| `CanDelete` | `true` | Allow authenticated deletes |
| `SecurityDisabled` | `false` | Skip all security checks for this model |
---
## MCP Tools
### Tool Naming
```
{operation}_{schema}_{entity} // e.g. read_public_users
{operation}_{entity} // e.g. read_users (when schema is empty)
```
Operations: `read`, `create`, `update`, `delete`.
### Read Tool — `read_{schema}_{entity}`
Fetch one or many records.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key value. Omit to return multiple records. |
| `limit` | number | Max records per page (recommended: 10100). |
| `offset` | number | Records to skip (offset-based pagination). |
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
| `columns` | array | Column names to include. Omit for all columns. |
| `omit_columns` | array | Column names to exclude. |
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
**Response:**
```json
{
"success": true,
"data": [...],
"metadata": {
"total": 100,
"filtered": 100,
"count": 10,
"limit": 10,
"offset": 0
}
}
```
### Create Tool — `create_{schema}_{entity}`
Insert one or more records.
| Argument | Type | Description |
|---|---|---|
| `data` | object \| array | Single object or array of objects to insert. |
Array input runs inside a single transaction — all succeed or all fail.
**Response:**
```json
{ "success": true, "data": { ... } }
```
### Update Tool — `update_{schema}_{entity}`
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
| Argument | Type | Description |
|---|---|---|
| `id` | string | Primary key of the record. Can also be included inside `data`. |
| `data` | object (required) | Fields to update. |
**Response:**
```json
{ "success": true, "data": { ...merged record... } }
```
### Delete Tool — `delete_{schema}_{entity}`
Delete a record by primary key. **Irreversible.**
| Argument | Type | Description |
|---|---|---|
| `id` | string (required) | Primary key of the record to delete. |
**Response:**
```json
{ "success": true, "data": { ...deleted record... } }
```
### Annotation Tool — `resolvespec_annotate`
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
| Argument | Type | Description |
|---|---|---|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
```json
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "set" }
```
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
```json
{ "tool_name": "read_public_users" }
```
**Response:**
```json
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
```
---
### Resource — `{schema}.{entity}`
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
---
## Filtering
Pass an array of filter objects to the `filters` argument:
```json
[
{ "column": "status", "operator": "=", "value": "active" },
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
]
```
### Supported Operators
| Operator | Aliases | Description |
|---|---|---|
| `=` | `eq` | Equal |
| `!=` | `neq`, `<>` | Not equal |
| `>` | `gt` | Greater than |
| `>=` | `gte` | Greater than or equal |
| `<` | `lt` | Less than |
| `<=` | `lte` | Less than or equal |
| `like` | | SQL LIKE (case-sensitive) |
| `ilike` | | SQL ILIKE (case-insensitive) |
| `in` | | Value in list |
| `is_null` | | Column IS NULL |
| `is_not_null` | | Column IS NOT NULL |
### Logic Operators
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
---
## Sorting
```json
[
{ "column": "created_at", "direction": "desc" },
{ "column": "name", "direction": "asc" }
]
```
---
## Pagination
### Offset-Based
```json
{ "limit": 20, "offset": 40 }
```
### Cursor-Based
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
```json
// Next page: pass the PK of the last record on the current page
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
// Previous page: pass the PK of the first record on the current page
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
```
---
## Preloading Relations
```json
[
{ "relation": "Profile" },
{ "relation": "Orders" }
]
```
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
---
## Hook System
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
### Hook Types
| Constant | Fires |
|---|---|
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
| `BeforeRead` / `AfterRead` | Around read queries |
| `BeforeCreate` / `AfterCreate` | Around insert |
| `BeforeUpdate` / `AfterUpdate` | Around update |
| `BeforeDelete` / `AfterDelete` | Around delete |
### Registering Hooks
```go
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
// Inject a timestamp before insert
if data, ok := ctx.Data.(map[string]interface{}); ok {
data["created_at"] = time.Now()
}
return nil
})
// Register the same hook for multiple events
handler.Hooks().RegisterMultiple(
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
auditHook,
)
```
### HookContext Fields
| Field | Type | Description |
|---|---|---|
| `Context` | `context.Context` | Request context |
| `Handler` | `*Handler` | The resolvemcp handler |
| `Schema` | `string` | Database schema name |
| `Entity` | `string` | Entity/table name |
| `Model` | `interface{}` | Registered model instance |
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
| `ID` | `string` | Primary key from request (read/update/delete) |
| `Data` | `interface{}` | Input data (create/update — modifiable) |
| `Result` | `interface{}` | Output data (set by After hooks) |
| `Error` | `error` | Operation error, if any |
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
| `Tx` | `common.Database` | Database/transaction handle |
| `Abort` | `bool` | Set to `true` to abort the operation |
| `AbortMessage` | `string` | Error message returned when aborting |
| `AbortCode` | `int` | Optional status code for the abort |
### Aborting an Operation
```go
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
ctx.Abort = true
ctx.AbortMessage = "deletion is disabled"
return nil
})
```
### Managing Hooks
```go
registry := handler.Hooks()
registry.HasHooks(resolvemcp.BeforeCreate) // bool
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
registry.ClearAll() // remove all hooks
```
---
## Context Helpers
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
```go
schema := resolvemcp.GetSchema(ctx)
entity := resolvemcp.GetEntity(ctx)
tableName := resolvemcp.GetTableName(ctx)
model := resolvemcp.GetModel(ctx)
modelPtr := resolvemcp.GetModelPtr(ctx)
```
You can also set values manually (e.g. in middleware):
```go
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
```
---
## Adding Custom MCP Tools
Access the underlying `*server.MCPServer` to register additional tools:
```go
mcpServer := handler.MCPServer()
mcpServer.AddTool(myTool, myHandler)
```
---
## Table Name Resolution
The handler resolves table names in priority order:
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)

View File

@@ -0,0 +1,107 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
)
const annotationToolName = "resolvespec_annotate"
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
// The tool lets models/entities store and retrieve freeform annotation records
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
func registerAnnotationTool(h *Handler) {
tool := mcp.NewTool(annotationToolName,
mcp.WithDescription(
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
"To set annotations: provide both 'tool_name' and 'annotations'. "+
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
"To get annotations: provide only 'tool_name'. "+
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
"a model/entity name (e.g. 'public.users'), or any other key.",
),
mcp.WithString("tool_name",
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
mcp.Required(),
),
mcp.WithObject("annotations",
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
toolName, ok := args["tool_name"].(string)
if !ok || toolName == "" {
return mcp.NewToolResultError("missing required argument: tool_name"), nil
}
annotations, hasAnnotations := args["annotations"]
if hasAnnotations && annotations != nil {
return executeSetAnnotation(ctx, h, toolName, annotations)
}
return executeGetAnnotation(ctx, h, toolName)
})
}
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
jsonBytes, err := json.Marshal(annotations)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
}
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "set",
})
}
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
var rows []map[string]interface{}
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
}
var annotations interface{}
if len(rows) > 0 {
// The procedure returns a single value; extract the first column of the first row.
for _, v := range rows[0] {
annotations = v
break
}
}
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
switch v := annotations.(type) {
case []byte:
var decoded interface{}
if json.Unmarshal(v, &decoded) == nil {
annotations = decoded
}
case string:
var decoded interface{}
if json.Unmarshal([]byte(v), &decoded) == nil {
annotations = decoded
}
}
return marshalResult(map[string]interface{}{
"success": true,
"tool_name": toolName,
"action": "get",
"annotations": annotations,
})
}

71
pkg/resolvemcp/context.go Normal file
View File

@@ -0,0 +1,71 @@
package resolvemcp
import "context"
type contextKey string
const (
contextKeySchema contextKey = "schema"
contextKeyEntity contextKey = "entity"
contextKeyTableName contextKey = "tableName"
contextKeyModel contextKey = "model"
contextKeyModelPtr contextKey = "modelPtr"
)
func WithSchema(ctx context.Context, schema string) context.Context {
return context.WithValue(ctx, contextKeySchema, schema)
}
func GetSchema(ctx context.Context) string {
if v := ctx.Value(contextKeySchema); v != nil {
return v.(string)
}
return ""
}
func WithEntity(ctx context.Context, entity string) context.Context {
return context.WithValue(ctx, contextKeyEntity, entity)
}
func GetEntity(ctx context.Context) string {
if v := ctx.Value(contextKeyEntity); v != nil {
return v.(string)
}
return ""
}
func WithTableName(ctx context.Context, tableName string) context.Context {
return context.WithValue(ctx, contextKeyTableName, tableName)
}
func GetTableName(ctx context.Context) string {
if v := ctx.Value(contextKeyTableName); v != nil {
return v.(string)
}
return ""
}
func WithModel(ctx context.Context, model interface{}) context.Context {
return context.WithValue(ctx, contextKeyModel, model)
}
func GetModel(ctx context.Context) interface{} {
return ctx.Value(contextKeyModel)
}
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
}
func GetModelPtr(ctx context.Context) interface{} {
return ctx.Value(contextKeyModelPtr)
}
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
ctx = WithSchema(ctx, schema)
ctx = WithEntity(ctx, entity)
ctx = WithTableName(ctx, tableName)
ctx = WithModel(ctx, model)
ctx = WithModelPtr(ctx, modelPtr)
return ctx
}

161
pkg/resolvemcp/cursor.go Normal file
View File

@@ -0,0 +1,161 @@
package resolvemcp
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
import (
"fmt"
"strings"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
type cursorDirection int
const (
cursorForward cursorDirection = 1
cursorBackward cursorDirection = -1
)
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
func getCursorFilter(
tableName string,
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
cursorID, direction := getActiveCursor(options)
if cursorID == "" {
return "", fmt.Errorf("no cursor provided for table %s", tableName)
}
sortItems := options.Sort
if len(sortItems) == 0 {
return "", fmt.Errorf("no sort columns defined")
}
var whereClauses []string
joinSQL := ""
reverse := direction < 0
for _, s := range sortItems {
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
desc := strings.EqualFold(s.Direction, "desc")
if reverse {
desc = !desc
}
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
if err != nil {
logger.Warn("Skipping invalid sort column %q: %v", col, err)
continue
}
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
op := "<"
if desc {
op = ">"
}
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
}
if len(whereClauses) == 0 {
return "", fmt.Errorf("no valid sort columns after filtering")
}
orSQL := buildCursorPriorityChain(whereClauses)
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
)
return query, nil
}
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
if options.CursorForward != "" {
return options.CursorForward, cursorForward
}
if options.CursorBackward != "" {
return options.CursorBackward, cursorBackward
}
return "", 0
}
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
return "cursor_select." + field, tableName + "." + field, false, nil
}
if prefix != "" && prefix != tableName {
return "", "", true, nil
}
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
func buildCursorPriorityChain(clauses []string) string {
var or []string
for i := 0; i < len(clauses); i++ {
and := strings.Join(clauses[:i+1], "\n AND ")
or = append(or, "("+and+")")
}
return strings.Join(or, "\n OR ")
}

743
pkg/resolvemcp/handler.go Normal file
View File

@@ -0,0 +1,743 @@
package resolvemcp
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"reflect"
"strings"
"sync"
"github.com/mark3labs/mcp-go/server"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// Handler exposes registered database models as MCP tools and resources.
type Handler struct {
db common.Database
registry common.ModelRegistry
hooks *HookRegistry
mcpServer *server.MCPServer
config Config
name string
version string
}
// NewHandler creates a Handler with the given database, model registry, and config.
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
h := &Handler{
db: db,
registry: registry,
hooks: NewHookRegistry(),
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
config: cfg,
name: "resolvemcp",
version: "1.0.0",
}
registerAnnotationTool(h)
return h
}
// Hooks returns the hook registry.
func (h *Handler) Hooks() *HookRegistry {
return h.hooks
}
// GetDatabase returns the underlying database.
func (h *Handler) GetDatabase() common.Database {
return h.db
}
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
func (h *Handler) MCPServer() *server.MCPServer {
return h.mcpServer
}
// SSEServer returns an http.Handler that serves MCP over SSE.
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
// detected automatically from each incoming request.
func (h *Handler) SSEServer() http.Handler {
if h.config.BaseURL != "" {
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
}
return &dynamicSSEHandler{h: h}
}
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
// client-server communication (POST for requests, GET for server-initiated messages).
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
func (h *Handler) StreamableHTTPServer() http.Handler {
return server.NewStreamableHTTPServer(h.mcpServer)
}
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
return server.NewSSEServer(
h.mcpServer,
server.WithBaseURL(baseURL),
server.WithBasePath(basePath),
)
}
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
type dynamicSSEHandler struct {
h *Handler
mu sync.Mutex
pool map[string]*server.SSEServer
}
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
baseURL := requestBaseURL(r)
d.mu.Lock()
if d.pool == nil {
d.pool = make(map[string]*server.SSEServer)
}
s, ok := d.pool[baseURL]
if !ok {
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
d.pool[baseURL] = s
}
d.mu.Unlock()
s.ServeHTTP(w, r)
}
// requestBaseURL builds the base URL from an incoming request.
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
func requestBaseURL(r *http.Request) string {
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
scheme = proto
}
return scheme + "://" + r.Host
}
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
fullName := buildModelName(schema, entity)
if err := h.registry.RegisterModel(fullName, model); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// RegisterModelWithRules registers a model and sets per-entity operation rules
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
fullName := buildModelName(schema, entity)
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
return err
}
registerModelTools(h, schema, entity, model)
return nil
}
// SetModelRules updates the operation rules for an already-registered model.
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
if !ok {
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
}
return reg.SetModelRules(buildModelName(schema, entity), rules)
}
// buildModelName builds the registry key for a model (same format as resolvespec).
func buildModelName(schema, entity string) string {
if schema == "" {
return entity
}
return fmt.Sprintf("%s.%s", schema, entity)
}
// getTableName returns the fully qualified table name for a model.
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
if schemaName != "" {
if h.db.DriverName() == "sqlite" {
return fmt.Sprintf("%s_%s", schemaName, tableName)
}
return fmt.Sprintf("%s.%s", schemaName, tableName)
}
return tableName
}
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
if tableProvider, ok := model.(common.TableNameProvider); ok {
tableName := tableProvider.TableName()
if idx := strings.LastIndex(tableName, "."); idx != -1 {
return tableName[:idx], tableName[idx+1:]
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), tableName
}
return defaultSchema, tableName
}
if schemaProvider, ok := model.(common.SchemaProvider); ok {
return schemaProvider.SchemaName(), entity
}
return defaultSchema, entity
}
// executeRead reads records from the database and returns raw data + metadata.
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (interface{}, *common.Metadata, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, nil, fmt.Errorf("model not found: %w", err)
}
unwrapped, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, nil, fmt.Errorf("invalid model: %w", err)
}
model = unwrapped.Model
modelType := unwrapped.ModelType
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
validator := common.NewColumnValidator(model)
options = validator.FilterRequestOptions(options)
// BeforeHandle hook
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "read",
Options: options,
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, nil, err
}
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
modelPtr := reflect.New(sliceType).Interface()
query := h.db.NewSelect().Model(modelPtr)
tempInstance := reflect.New(modelType).Interface()
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
query = query.Table(tableName)
}
// Column selection
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
options.Columns = reflection.GetSQLModelColumns(model)
}
for _, col := range options.Columns {
query = query.Column(reflection.ExtractSourceColumn(col))
}
for _, cu := range options.ComputedColumns {
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
}
// Preloads
if len(options.Preload) > 0 {
var err error
query, err = h.applyPreloads(model, query, options.Preload)
if err != nil {
return nil, nil, fmt.Errorf("failed to apply preloads: %w", err)
}
}
// Filters
query = h.applyFilters(query, options.Filters)
// Custom operators
for _, customOp := range options.CustomOperators {
query = query.Where(customOp.SQL)
}
// Sorting
for _, sort := range options.Sort {
direction := "ASC"
if strings.EqualFold(sort.Direction, "desc") {
direction = "DESC"
}
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
}
// Cursor pagination
if options.CursorForward != "" || options.CursorBackward != "" {
pkName := reflection.GetPrimaryKeyName(model)
modelColumns := reflection.GetModelColumns(model)
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// expandJoins is empty for resolvemcp — no custom SQL join support yet
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
return nil, nil, fmt.Errorf("cursor error: %w", err)
}
if cursorFilter != "" {
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
sanitized = common.EnsureOuterParentheses(sanitized)
if sanitized != "" {
query = query.Where(sanitized)
}
}
}
// Count
total, err := query.Count(ctx)
if err != nil {
return nil, nil, fmt.Errorf("error counting records: %w", err)
}
// Pagination
if options.Limit != nil && *options.Limit > 0 {
query = query.Limit(*options.Limit)
}
if options.Offset != nil && *options.Offset > 0 {
query = query.Offset(*options.Offset)
}
// BeforeRead hook
hookCtx.Query = query
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
return nil, nil, err
}
var data interface{}
if id != "" {
singleResult := reflect.New(modelType).Interface()
pkName := reflection.GetPrimaryKeyName(singleResult)
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := query.Scan(ctx, singleResult); err != nil {
if err == sql.ErrNoRows {
return nil, nil, fmt.Errorf("record not found")
}
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = singleResult
} else {
if err := query.Scan(ctx, modelPtr); err != nil {
return nil, nil, fmt.Errorf("query error: %w", err)
}
data = reflect.ValueOf(modelPtr).Elem().Interface()
}
limit := 0
offset := 0
if options.Limit != nil {
limit = *options.Limit
}
if options.Offset != nil {
offset = *options.Offset
}
// Count is the number of records in this page, not the total.
var pageCount int64
if id != "" {
pageCount = 1
} else {
pageCount = int64(reflect.ValueOf(data).Len())
}
metadata := &common.Metadata{
Total: int64(total),
Filtered: int64(total),
Count: pageCount,
Limit: limit,
Offset: offset,
}
// AfterRead hook
hookCtx.Result = data
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
return nil, nil, err
}
return data, metadata, nil
}
// executeCreate inserts one or more records.
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "create",
Data: data,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
return nil, err
}
// Use potentially modified data
data = hookCtx.Data
switch v := data.(type) {
case map[string]interface{}:
query := h.db.NewInsert().Table(tableName)
for key, value := range v {
query = query.Value(key, value)
}
if _, err := query.Exec(ctx); err != nil {
return nil, fmt.Errorf("create error: %w", err)
}
hookCtx.Result = v
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return v, nil
case []interface{}:
results := make([]interface{}, 0, len(v))
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
for _, item := range v {
itemMap, ok := item.(map[string]interface{})
if !ok {
return fmt.Errorf("each item must be an object")
}
q := tx.NewInsert().Table(tableName)
for key, value := range itemMap {
q = q.Value(key, value)
}
if _, err := q.Exec(ctx); err != nil {
return err
}
results = append(results, item)
}
return nil
})
if err != nil {
return nil, fmt.Errorf("batch create error: %w", err)
}
hookCtx.Result = results
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
}
return results, nil
default:
return nil, fmt.Errorf("data must be an object or array of objects")
}
}
// executeUpdate updates a record by ID.
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (interface{}, error) {
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
updates, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("data must be an object")
}
if id == "" {
if idVal, exists := updates["id"]; exists {
id = fmt.Sprintf("%v", idVal)
}
}
if id == "" {
return nil, fmt.Errorf("update requires an ID")
}
pkName := reflection.GetPrimaryKeyName(model)
var updateResult interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
// Read existing record
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
existingRecord := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("no records found to update")
}
return fmt.Errorf("error fetching existing record: %w", err)
}
// Convert to map
existingMap := make(map[string]interface{})
jsonData, err := json.Marshal(existingRecord)
if err != nil {
return fmt.Errorf("error marshaling existing record: %w", err)
}
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
return fmt.Errorf("error unmarshaling existing record: %w", err)
}
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "update",
ID: id,
Data: updates,
Tx: tx,
}
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
return err
}
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
updates = modifiedData
}
// Merge non-nil, non-empty values
for key, newValue := range updates {
if newValue == nil {
continue
}
if strVal, ok := newValue.(string); ok && strVal == "" {
continue
}
existingMap[key] = newValue
}
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
res, err := q.Exec(ctx)
if err != nil {
return fmt.Errorf("error updating record: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("no records found to update")
}
updateResult = existingMap
hookCtx.Result = updateResult
return h.hooks.Execute(AfterUpdate, hookCtx)
})
if err != nil {
return nil, err
}
return updateResult, nil
}
// executeDelete deletes a record by ID.
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (interface{}, error) {
if id == "" {
return nil, fmt.Errorf("delete requires an ID")
}
model, err := h.registry.GetModelByEntity(schema, entity)
if err != nil {
return nil, fmt.Errorf("model not found: %w", err)
}
result, err := common.ValidateAndUnwrapModel(model)
if err != nil {
return nil, fmt.Errorf("invalid model: %w", err)
}
model = result.Model
tableName := h.getTableName(schema, entity, model)
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
pkName := reflection.GetPrimaryKeyName(model)
hookCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Operation: "delete",
ID: id,
Tx: h.db,
}
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
return nil, err
}
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
return nil, err
}
modelType := reflect.TypeOf(model)
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
var recordToDelete interface{}
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
record := reflect.New(modelType).Interface()
selectQuery := tx.NewSelect().Model(record).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
if err := selectQuery.ScanModel(ctx); err != nil {
if err == sql.ErrNoRows {
return fmt.Errorf("record not found")
}
return fmt.Errorf("error fetching record: %w", err)
}
res, err := tx.NewDelete().Table(tableName).
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
Exec(ctx)
if err != nil {
return fmt.Errorf("delete error: %w", err)
}
if res.RowsAffected() == 0 {
return fmt.Errorf("record not found or already deleted")
}
recordToDelete = record
hookCtx.Tx = tx
hookCtx.Result = record
return h.hooks.Execute(AfterDelete, hookCtx)
})
if err != nil {
return nil, err
}
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
return recordToDelete, nil
}
// applyFilters applies all filters with OR grouping logic.
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
if len(filters) == 0 {
return query
}
i := 0
for i < len(filters) {
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
if startORGroup {
orGroup := []common.FilterOption{filters[i]}
j := i + 1
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
orGroup = append(orGroup, filters[j])
j++
}
query = h.applyFilterGroup(query, orGroup)
i = j
} else {
condition, args := h.buildFilterCondition(filters[i])
if condition != "" {
query = query.Where(condition, args...)
}
i++
}
}
return query
}
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
var conditions []string
var args []interface{}
for _, filter := range filters {
condition, filterArgs := h.buildFilterCondition(filter)
if condition != "" {
conditions = append(conditions, condition)
args = append(args, filterArgs...)
}
}
if len(conditions) == 0 {
return query
}
if len(conditions) == 1 {
return query.Where(conditions[0], args...)
}
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
}
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
switch filter.Operator {
case "eq", "=":
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
case "neq", "!=", "<>":
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
case "gt", ">":
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
case "gte", ">=":
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
case "lt", "<":
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
case "lte", "<=":
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
case "in":
condition, args := common.BuildInCondition(filter.Column, filter.Value)
return condition, args
case "is_null":
return fmt.Sprintf("%s IS NULL", filter.Column), nil
case "is_not_null":
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
}
return "", nil
}
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
for _, preload := range preloads {
if preload.Relation == "" {
continue
}
query = query.PreloadRelation(preload.Relation)
}
return query, nil
}

113
pkg/resolvemcp/hooks.go Normal file
View File

@@ -0,0 +1,113 @@
package resolvemcp
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
)
// HookType defines the type of hook to execute
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
BeforeHandle HookType = "before_handle"
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
BeforeCreate HookType = "before_create"
AfterCreate HookType = "after_create"
BeforeUpdate HookType = "before_update"
AfterUpdate HookType = "after_update"
BeforeDelete HookType = "before_delete"
AfterDelete HookType = "after_delete"
)
// HookContext contains all the data available to a hook
type HookContext struct {
Context context.Context
Handler *Handler
Schema string
Entity string
Model interface{}
Options common.RequestOptions
Operation string
ID string
Data interface{}
Result interface{}
Error error
Query common.SelectQuery
Abort bool
AbortMessage string
AbortCode int
Tx common.Database
}
// HookFunc is the signature for hook functions
type HookFunc func(*HookContext) error
// HookRegistry manages all registered hooks
type HookRegistry struct {
hooks map[HookType][]HookFunc
}
func NewHookRegistry() *HookRegistry {
return &HookRegistry{
hooks: make(map[HookType][]HookFunc),
}
}
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
if r.hooks == nil {
r.hooks = make(map[HookType][]HookFunc)
}
r.hooks[hookType] = append(r.hooks[hookType], hook)
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
}
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
for _, hookType := range hookTypes {
r.Register(hookType, hook)
}
}
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
hooks, exists := r.hooks[hookType]
if !exists || len(hooks) == 0 {
return nil
}
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
for i, hook := range hooks {
if err := hook(ctx); err != nil {
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
if ctx.Abort {
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil
}
func (r *HookRegistry) Clear(hookType HookType) {
delete(r.hooks, hookType)
}
func (r *HookRegistry) ClearAll() {
r.hooks = make(map[HookType][]HookFunc)
}
func (r *HookRegistry) HasHooks(hookType HookType) bool {
hooks, exists := r.hooks[hookType]
return exists && len(hooks) > 0
}

View File

@@ -0,0 +1,133 @@
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
// and resources over HTTP/SSE transport.
//
// It mirrors the resolvespec package patterns:
// - Same model registration API
// - Same filter, sort, cursor pagination, preload options
// - Same lifecycle hook system
//
// Usage:
//
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
// handler.RegisterModel("public", "users", &User{})
//
// r := mux.NewRouter()
// resolvemcp.SetupMuxRoutes(r, handler)
package resolvemcp
import (
"net/http"
"github.com/gorilla/mux"
"github.com/uptrace/bun"
bunrouter "github.com/uptrace/bunrouter"
"gorm.io/gorm"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
)
// Config holds configuration for the resolvemcp handler.
type Config struct {
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
BaseURL string
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
// If empty, the path is detected from each incoming request automatically.
BasePath string
}
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
}
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
}
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
//
// To protect these routes with authentication, wrap the mux router or apply middleware
// before calling SetupMuxRoutes.
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
// Convenience: also expose the full SSE server at basePath for clients that
// use ServeHTTP directly (e.g. net/http default mux).
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
// using the base path from Config.BasePath.
//
// Two routes are registered:
// - GET {basePath}/sse — SSE connection endpoint
// - POST {basePath}/message — JSON-RPC message endpoint
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.SSEServer()
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
}
// NewSSEServer returns an http.Handler that serves MCP over SSE.
// If Config.BasePath is set it is used directly; otherwise the base path is
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
//
// h := resolvemcp.NewSSEServer(handler)
// http.Handle("/api/mcp/", h)
func NewSSEServer(handler *Handler) http.Handler {
return handler.SSEServer()
}
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
// POST for client→server messages, GET for server→client streaming.
//
// Example:
//
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
}
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
basePath := handler.config.BasePath
h := handler.StreamableHTTPServer()
router.GET(basePath, bunrouter.HTTPHandler(h))
router.POST(basePath, bunrouter.HTTPHandler(h))
router.DELETE(basePath, bunrouter.HTTPHandler(h))
}
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
// Mount it at the desired path; that path becomes the MCP endpoint.
//
// h := resolvemcp.NewStreamableHTTPHandler(handler)
// http.Handle("/mcp", h)
// engine.Any("/mcp", gin.WrapH(h))
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
return handler.StreamableHTTPServer()
}

View File

@@ -0,0 +1,115 @@
package resolvemcp
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
)
// RegisterSecurityHooks wires the security package's access-control layer into the
// resolvemcp handler. Call it once after creating the handler, before registering models.
//
// The following controls are applied:
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
// stored via RegisterModelWithRules / SetModelRules.
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
// - Column-level security: sensitive columns masked/hidden in read results.
// - Audit logging after each read.
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// BeforeHandle: enforce model-level operation rules (auth check).
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
})
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
})
// AfterRead (2nd): audit log.
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
return security.LogDataAccess(newSecurityContext(hookCtx))
})
// BeforeUpdate: enforce CanUpdate rule.
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
})
// BeforeDelete: enforce CanDelete rule.
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
})
logger.Info("Security hooks registered for resolvemcp handler")
}
// --------------------------------------------------------------------------
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
// --------------------------------------------------------------------------
type securityContext struct {
ctx *HookContext
}
func newSecurityContext(ctx *HookContext) security.SecurityContext {
return &securityContext{ctx: ctx}
}
func (s *securityContext) GetContext() context.Context {
return s.ctx.Context
}
func (s *securityContext) GetUserID() (int, bool) {
return security.GetUserID(s.ctx.Context)
}
func (s *securityContext) GetSchema() string {
return s.ctx.Schema
}
func (s *securityContext) GetEntity() string {
return s.ctx.Entity
}
func (s *securityContext) GetModel() interface{} {
return s.ctx.Model
}
func (s *securityContext) GetQuery() interface{} {
return s.ctx.Query
}
func (s *securityContext) SetQuery(query interface{}) {
if q, ok := query.(common.SelectQuery); ok {
s.ctx.Query = q
}
}
func (s *securityContext) GetResult() interface{} {
return s.ctx.Result
}
func (s *securityContext) SetResult(result interface{}) {
s.ctx.Result = result
}

692
pkg/resolvemcp/tools.go Normal file
View File

@@ -0,0 +1,692 @@
package resolvemcp
import (
"context"
"encoding/json"
"fmt"
"reflect"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/reflection"
)
// toolName builds the MCP tool name for a given operation and model.
func toolName(operation, schema, entity string) string {
if schema == "" {
return fmt.Sprintf("%s_%s", operation, entity)
}
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
}
// registerModelTools registers the four CRUD tools and resource for a model.
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
info := buildModelInfo(schema, entity, model)
registerReadTool(h, schema, entity, info)
registerCreateTool(h, schema, entity, info)
registerUpdateTool(h, schema, entity, info)
registerDeleteTool(h, schema, entity, info)
registerModelResource(h, schema, entity, info)
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
}
// --------------------------------------------------------------------------
// Model introspection
// --------------------------------------------------------------------------
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
type modelInfo struct {
fullName string // e.g. "public.users"
pkName string // e.g. "id"
columns []columnInfo
relationNames []string
schemaDoc string // formatted multi-line schema listing
}
type columnInfo struct {
jsonName string
sqlName string
goType string
sqlType string
isPrimary bool
isUnique bool
isFK bool
nullable bool
}
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
info := modelInfo{
fullName: buildModelName(schema, entity),
pkName: reflection.GetPrimaryKeyName(model),
}
// Unwrap to base struct type
modelType := reflect.TypeOf(model)
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return info
}
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
for _, d := range details {
// Derive the JSON name from the struct field
jsonName := fieldJSONName(modelType, d.Name)
if jsonName == "" || jsonName == "-" {
continue
}
// Skip relation fields (slice or user-defined struct that isn't time.Time).
fieldType, found := modelType.FieldByName(d.Name)
if found {
ft := fieldType.Type
if ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
if ft.Kind() == reflect.Slice || isUserStruct {
info.relationNames = append(info.relationNames, jsonName)
continue
}
}
sqlName := d.SQLName
if sqlName == "" {
sqlName = jsonName
}
// Derive Go type name, unwrapping pointer if needed.
goType := d.DataType
if goType == "" && found {
ft := fieldType.Type
for ft.Kind() == reflect.Ptr {
ft = ft.Elem()
}
goType = ft.Name()
}
// isPrimary: use both the GORM-tag detection and a name comparison against
// the known primary key (handles camelCase "primaryKey" tags correctly).
isPrimary := d.SQLKey == "primary_key" ||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
ci := columnInfo{
jsonName: jsonName,
sqlName: sqlName,
goType: goType,
sqlType: d.SQLDataType,
isPrimary: isPrimary,
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
isFK: d.SQLKey == "foreign_key",
nullable: d.Nullable,
}
info.columns = append(info.columns, ci)
}
info.schemaDoc = buildSchemaDoc(info)
return info
}
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
func fieldJSONName(modelType reflect.Type, fieldName string) string {
field, ok := modelType.FieldByName(fieldName)
if !ok {
return fieldName
}
tag := field.Tag.Get("json")
if tag == "" {
return fieldName
}
parts := strings.SplitN(tag, ",", 2)
if parts[0] == "" {
return fieldName
}
return parts[0]
}
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
func buildSchemaDoc(info modelInfo) string {
if len(info.columns) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("Columns:\n")
for _, c := range info.columns {
line := fmt.Sprintf(" • %s", c.jsonName)
typeDesc := c.goType
if c.sqlType != "" {
typeDesc = c.sqlType
}
if typeDesc != "" {
line += fmt.Sprintf(" (%s)", typeDesc)
}
var flags []string
if c.isPrimary {
flags = append(flags, "primary key")
}
if c.isUnique {
flags = append(flags, "unique")
}
if c.isFK {
flags = append(flags, "foreign key")
}
if !c.nullable && !c.isPrimary {
flags = append(flags, "not null")
} else if c.nullable {
flags = append(flags, "nullable")
}
if len(flags) > 0 {
line += " — " + strings.Join(flags, ", ")
}
sb.WriteString(line + "\n")
}
if len(info.relationNames) > 0 {
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
}
return sb.String()
}
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
func columnNameList(cols []columnInfo) string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.jsonName
}
return strings.Join(names, ", ")
}
// writableColumnNames returns JSON names for all non-primary-key columns.
func writableColumnNames(cols []columnInfo) []string {
var names []string
for _, c := range cols {
if !c.isPrimary {
names = append(names, c.jsonName)
}
}
return names
}
// --------------------------------------------------------------------------
// Read tool
// --------------------------------------------------------------------------
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("read", schema, entity)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
}
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
descParts = append(descParts,
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
)
if len(info.relationNames) > 0 {
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
}
description := strings.Join(descParts, "\n\n")
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
if len(info.columns) > 0 {
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
if len(info.columns) > 0 {
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
),
mcp.WithNumber("limit",
mcp.Description("Maximum number of records to return per page. Recommended: 10100."),
),
mcp.WithNumber("offset",
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
),
mcp.WithString("cursor_forward",
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithString("cursor_backward",
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
),
mcp.WithArray("columns",
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("omit_columns",
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
),
mcp.WithArray("filters",
mcp.Description(filterDesc),
),
mcp.WithArray("sort",
mcp.Description(sortDesc),
),
mcp.WithArray("preloads",
mcp.Description(buildPreloadDesc(info)),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
options := parseRequestOptions(args)
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": data,
"metadata": metadata,
})
})
}
func buildPreloadDesc(info modelInfo) string {
if len(info.relationNames) == 0 {
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
}
return fmt.Sprintf(
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
strings.Join(info.relationNames, ", "),
)
}
// --------------------------------------------------------------------------
// Create tool
// --------------------------------------------------------------------------
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("create", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
}
descParts = append(descParts,
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
dataDesc := "Record fields to create."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
}
dataDesc += " Pass a single object or an array of objects."
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
result, err := h.executeCreate(ctx, schema, entity, data)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Update tool
// --------------------------------------------------------------------------
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("update", schema, entity)
writable := writableColumnNames(info.columns)
var descParts []string
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
}
if len(writable) > 0 {
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
}
descParts = append(descParts,
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
)
if info.schemaDoc != "" {
descParts = append(descParts, info.schemaDoc)
}
description := strings.Join(descParts, "\n\n")
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
if len(writable) > 0 {
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
}
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(idDesc),
),
mcp.WithObject("data",
mcp.Description(dataDesc),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
data, ok := args["data"]
if !ok {
return mcp.NewToolResultError("missing required argument: data"), nil
}
dataMap, ok := data.(map[string]interface{})
if !ok {
return mcp.NewToolResultError("data must be an object"), nil
}
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Delete tool
// --------------------------------------------------------------------------
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
name := toolName("delete", schema, entity)
descParts := []string{
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
}
if info.pkName != "" {
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
}
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
description := strings.Join(descParts, " ")
tool := mcp.NewTool(name,
mcp.WithDescription(description),
mcp.WithString("id",
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
mcp.Required(),
),
)
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
args := req.GetArguments()
id, _ := args["id"].(string)
result, err := h.executeDelete(ctx, schema, entity, id)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
return marshalResult(map[string]interface{}{
"success": true,
"data": result,
})
})
}
// --------------------------------------------------------------------------
// Resource registration
// --------------------------------------------------------------------------
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
resourceURI := info.fullName
var resourceDesc strings.Builder
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
if info.pkName != "" {
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
}
if info.schemaDoc != "" {
resourceDesc.WriteString("\n\n")
resourceDesc.WriteString(info.schemaDoc)
}
resource := mcp.NewResource(
resourceURI,
entity,
mcp.WithResourceDescription(resourceDesc.String()),
mcp.WithMIMEType("application/json"),
)
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
limit := 100
options := common.RequestOptions{Limit: &limit}
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
if err != nil {
return nil, err
}
payload := map[string]interface{}{
"data": data,
"metadata": metadata,
}
jsonBytes, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling resource: %w", err)
}
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "application/json",
Text: string(jsonBytes),
},
}, nil
})
}
// --------------------------------------------------------------------------
// Argument parsing helpers
// --------------------------------------------------------------------------
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
options := common.RequestOptions{}
if v, ok := args["limit"]; ok {
switch n := v.(type) {
case float64:
limit := int(n)
options.Limit = &limit
case int:
options.Limit = &n
}
}
if v, ok := args["offset"]; ok {
switch n := v.(type) {
case float64:
offset := int(n)
options.Offset = &offset
case int:
options.Offset = &n
}
}
if v, ok := args["cursor_forward"].(string); ok {
options.CursorForward = v
}
if v, ok := args["cursor_backward"].(string); ok {
options.CursorBackward = v
}
options.Columns = parseStringArray(args["columns"])
options.OmitColumns = parseStringArray(args["omit_columns"])
options.Filters = parseFilters(args["filters"])
options.Sort = parseSortOptions(args["sort"])
options.Preload = parsePreloadOptions(args["preloads"])
return options
}
func parseStringArray(raw interface{}) []string {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]string, 0, len(items))
for _, item := range items {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}
func parseFilters(raw interface{}) []common.FilterOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.FilterOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var f common.FilterOption
if err := json.Unmarshal(b, &f); err != nil {
continue
}
if f.Column == "" || f.Operator == "" {
continue
}
if strings.EqualFold(f.LogicOperator, "or") {
f.LogicOperator = "OR"
} else {
f.LogicOperator = "AND"
}
result = append(result, f)
}
return result
}
func parseSortOptions(raw interface{}) []common.SortOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.SortOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var s common.SortOption
if err := json.Unmarshal(b, &s); err != nil {
continue
}
if s.Column == "" {
continue
}
result = append(result, s)
}
return result
}
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
if raw == nil {
return nil
}
items, ok := raw.([]interface{})
if !ok {
return nil
}
result := make([]common.PreloadOption, 0, len(items))
for _, item := range items {
b, err := json.Marshal(item)
if err != nil {
continue
}
var p common.PreloadOption
if err := json.Unmarshal(b, &p); err != nil {
continue
}
if p.Relation == "" {
continue
}
result = append(result, p)
}
return result
}
// marshalResult marshals a value to JSON and returns it as an MCP text result.
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
b, err := json.Marshal(v)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
}
return mcp.NewToolResultText(string(b)), nil
}

View File

@@ -644,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
CreatedAt time.Time `json:"created_at"`
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
}
// Schema.Table format
handler.registry.RegisterModel("core.users", &User{})
handler.registry.RegisterModel("core.posts", &Post{})
@@ -654,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
```go
package main
import (
"log"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
"github.com/gorilla/mux"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)

View File

@@ -24,6 +24,7 @@ const (
// - pkName: primary key column (e.g. "id")
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
// - options: the request options containing sort and cursor information
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
//
// Returns SQL snippet to embed in WHERE clause.
func GetCursorFilter(
@@ -31,8 +32,10 @@ func GetCursorFilter(
pkName string,
modelColumns []string,
options common.RequestOptions,
expandJoins map[string]string,
) (string, error) {
// Remove schema prefix if present
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -57,18 +60,19 @@ func GetCursorFilter(
// 3. Prepare
// --------------------------------------------------------------------- //
var whereClauses []string
joinSQL := ""
reverse := direction < 0
// --------------------------------------------------------------------- //
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
// Parse: "created_at", "user.name", etc.
// Parse: "created_at", "user.name", "fn.sortorder", etc.
parts := strings.Split(col, ".")
field := strings.TrimSpace(parts[len(parts)-1])
prefix := strings.Join(parts[:len(parts)-1], ".")
@@ -81,7 +85,7 @@ func GetCursorFilter(
}
// Resolve column
cursorCol, targetCol, err := resolveColumn(
cursorCol, targetCol, isJoin, err := resolveColumn(
field, prefix, tableName, modelColumns,
)
if err != nil {
@@ -89,6 +93,22 @@ func GetCursorFilter(
continue
}
// Handle joins
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
// Build inequality
op := "<"
if desc {
@@ -112,10 +132,12 @@ func GetCursorFilter(
query := fmt.Sprintf(`EXISTS (
SELECT 1
FROM %s cursor_select
%s
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
joinSQL,
pkName,
cursorID,
orSQL,
@@ -136,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
return "", 0
}
// Helper: resolve column (main table only for now)
// Helper: resolve column (main table or join)
func resolveColumn(
field, prefix, tableName string,
modelColumns []string,
) (cursorCol, targetCol string, err error) {
) (cursorCol, targetCol string, isJoin bool, err error) {
// JSON field
if strings.Contains(field, "->") {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Main table column
if modelColumns != nil {
for _, col := range modelColumns {
if strings.EqualFold(col, field) {
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
}
} else {
// No validation → allow all main-table fields
return "cursor_select." + field, tableName + "." + field, nil
return "cursor_select." + field, tableName + "." + field, false, nil
}
// Joined column (not supported in resolvespec yet)
// Joined column
if prefix != "" && prefix != tableName {
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
return "", "", true, nil
}
return "", "", fmt.Errorf("invalid column: %s", field)
return "", "", false, fmt.Errorf("invalid column: %s", field)
}
// Helper: rewrite JOIN clause for cursor subquery
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
cursorAlias = "cursor_select_" + alias
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
return joinSQL, cursorAlias
}
// ------------------------------------------------------------------------- //

View File

@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at", "user_id"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "created_at"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no cursor is provided")
}
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title"}
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err == nil {
t.Error("Expected error when no sort columns are defined")
}
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "title", "priority", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
@@ -170,19 +170,50 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "name", "email"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
options := common.RequestOptions{
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
CursorForward: "8975",
}
tableName := "core.account"
pkName := "rid_account"
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestGetActiveCursor(t *testing.T) {
tests := []struct {
name string
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
wantErr: false,
},
{
name: "Joined column (not supported)",
name: "Joined column (isJoin=true, no error)",
field: "name",
prefix: "user",
tableName: "posts",
modelColumns: []string{"id", "title"},
wantErr: true,
wantErr: false,
// cursorCol and targetCol are empty when isJoin=true; handled by caller
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
if tt.wantErr {
if err == nil {
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
t.Fatalf("Unexpected error: %v", err)
}
// For join columns, cursor/target are empty and isJoin=true
if isJoin {
if cursor != "" || target != "" {
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
}
return
}
if cursor != tt.wantCursor {
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
}
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
pkName := "id"
modelColumns := []string{"id", "created_at"}
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}

View File

@@ -44,8 +44,8 @@ func TestBuildFilterCondition(t *testing.T) {
Operator: "in",
Value: []string{"active", "pending"},
},
expectedCondition: "status IN (?)",
expectedArgsCount: 1,
expectedCondition: "status IN (?,?)",
expectedArgsCount: 2,
},
{
name: "LIKE operator",

View File

@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
validator := common.NewColumnValidator(model)
req.Options = validator.FilterRequestOptions(req.Options)
// Execute BeforeHandle hook - auth check fires here, after model resolution
beforeCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Writer: w,
Request: r,
Operation: req.Operation,
}
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
code := http.StatusUnauthorized
if beforeCtx.AbortCode != 0 {
code = beforeCtx.AbortCode
}
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
return
}
switch req.Operation {
case "read":
h.handleRead(ctx, w, id, req.Options)
@@ -309,8 +329,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation
modelColumns := reflection.GetModelColumns(model)
// Get cursor filter SQL
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
if err != nil {
logger.Error("Error building cursor filter: %v", err)
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
@@ -1501,22 +1526,22 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
var args []interface{}
switch filter.Operator {
case "eq":
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq":
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt":
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte":
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt":
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte":
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
@@ -1526,8 +1551,10 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in":
condition = fmt.Sprintf("%s IN (?)", filter.Column)
args = []interface{}{filter.Value}
condition, args = common.BuildInCondition(filter.Column, filter.Value)
if condition == "" {
return "", nil
}
default:
return "", nil
}
@@ -1543,22 +1570,22 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
var args []interface{}
switch filter.Operator {
case "eq":
case "eq", "=":
condition = fmt.Sprintf("%s = ?", filter.Column)
args = []interface{}{filter.Value}
case "neq":
case "neq", "!=", "<>":
condition = fmt.Sprintf("%s != ?", filter.Column)
args = []interface{}{filter.Value}
case "gt":
case "gt", ">":
condition = fmt.Sprintf("%s > ?", filter.Column)
args = []interface{}{filter.Value}
case "gte":
case "gte", ">=":
condition = fmt.Sprintf("%s >= ?", filter.Column)
args = []interface{}{filter.Value}
case "lt":
case "lt", "<":
condition = fmt.Sprintf("%s < ?", filter.Column)
args = []interface{}{filter.Value}
case "lte":
case "lte", "<=":
condition = fmt.Sprintf("%s <= ?", filter.Column)
args = []interface{}{filter.Value}
case "like":
@@ -1568,8 +1595,10 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
args = []interface{}{filter.Value}
case "in":
condition = fmt.Sprintf("%s IN (?)", filter.Column)
args = []interface{}{filter.Value}
condition, args = common.BuildInCondition(filter.Column, filter.Value)
if condition == "" {
return query
}
default:
return query
}

View File

@@ -12,6 +12,10 @@ import (
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
// Use this for auth checks that need model rules and user context simultaneously.
BeforeHandle HookType = "before_handle"
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
@@ -43,6 +47,9 @@ type HookContext struct {
Writer common.ResponseWriter
Request common.Request
// Operation being dispatched (e.g. "read", "create", "update", "delete")
Operation string
// Operation-specific fields
ID string
Data interface{} // For create/update operations

View File

@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
// Create handler functions for this specific entity
postEntityHandler := createMuxHandler(handler, schema, entity, "")
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
postEntityHandler = authMiddleware(postEntityHandler)
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
getEntityHandler = authMiddleware(getEntityHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
@@ -225,7 +225,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = handler(w, req)
// Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
})
// Wrap with auth middleware and execute

View File

@@ -2,6 +2,7 @@ package resolvespec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/common"
"github.com/bitechdev/ResolveSpec/pkg/logger"
@@ -10,6 +11,17 @@ import (
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)

View File

@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
```
**Available Hook Types**:
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
* `BeforeRead`, `AfterRead`
* `BeforeCreate`, `AfterCreate`
* `BeforeUpdate`, `AfterUpdate`
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
* `Handler`: Access to handler, database, and registry
* `Schema`, `Entity`, `TableName`: Request info
* `Model`: The registered model type
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
* `Options`: Parsed request options (filters, sorting, etc.)
* `ID`: Record ID (for single-record operations)
* `Data`: Request data (for create/update)
* `Result`: Operation result (for after hooks)
* `Writer`: Response writer (allows hooks to modify response)
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
## Cursor Pagination

View File

@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
modelColumns []string, // optional: for validation
expandJoins map[string]string, // optional: alias → JOIN SQL
) (string, error) {
// Separate schema prefix from bare table name
fullTableName := tableName
if strings.Contains(tableName, ".") {
tableName = strings.SplitN(tableName, ".", 2)[1]
}
@@ -62,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
// 4. Process each sort column
// --------------------------------------------------------------------- //
for _, s := range sortItems {
col := strings.TrimSpace(s.Column)
col := strings.Trim(strings.TrimSpace(s.Column), "()")
if col == "" {
continue
}
@@ -91,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
}
// Handle joins
if isJoin && expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
if isJoin {
if expandJoins != nil {
if joinClause, ok := expandJoins[prefix]; ok {
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
joinSQL = jSQL
cursorCol = cRef + "." + field
targetCol = prefix + "." + field
}
}
if cursorCol == "" {
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
continue
}
}
@@ -127,7 +135,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
WHERE cursor_select.%s = %s
AND (%s)
)`,
tableName,
fullTableName,
joinSQL,
pkName,
cursorID,

View File

@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
t.Fatalf("GetCursorFilter failed: %v", err)
}
// Should handle schema prefix properly
if !strings.Contains(filter, "users") {
t.Errorf("Filter should reference table name users, got: %s", filter)
// Should include full schema-qualified name in FROM clause
if !strings.Contains(filter, "public.users") {
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
}
t.Logf("Generated cursor filter with schema: %s", filter)
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
}
}
func TestGetCursorFilter_LateralJoin(t *testing.T) {
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
opts := &ExtendedRequestOptions{
RequestOptions: common.RequestOptions{
Sort: []common.SortOption{
{Column: "fn.sortorder", Direction: "ASC"},
},
},
}
opts.CursorForward = "8975"
tableName := "core.account"
pkName := "rid_account"
// modelColumns does not contain "sortorder" - it's a lateral join computed column
modelColumns := []string{"rid_account", "description", "pastelno"}
expandJoins := map[string]string{"fn": lateralJoin}
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
if err != nil {
t.Fatalf("GetCursorFilter failed: %v", err)
}
t.Logf("Generated lateral cursor filter: %s", filter)
// Should contain the rewritten lateral join inside the EXISTS subquery
if !strings.Contains(filter, "cursor_select_fn") {
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
}
// Should compare fn.sortorder values
if !strings.Contains(filter, "sortorder") {
t.Errorf("Filter should reference sortorder column, got: %s", filter)
}
// Should NOT contain empty comparison like "< "
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
}
}
func TestBuildPriorityChain(t *testing.T) {
clauses := []string{
"cursor_select.priority > posts.priority",

View File

@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
// Add request-scoped data to context (including options)
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
// Derive operation for auth check
var operation string
switch method {
case "GET":
operation = "read"
case "POST":
operation = "create"
case "PUT", "PATCH":
operation = "update"
case "DELETE":
operation = "delete"
default:
operation = "read"
}
// Execute BeforeHandle hook - auth check fires here, after model resolution
beforeCtx := &HookContext{
Context: ctx,
Handler: h,
Schema: schema,
Entity: entity,
Model: model,
Writer: w,
Request: r,
Operation: operation,
}
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
code := http.StatusUnauthorized
if beforeCtx.AbortCode != 0 {
code = beforeCtx.AbortCode
}
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
return
}
switch method {
case "GET":
if id != "" {
@@ -688,12 +723,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
// Extract model columns for validation using the generic database function
modelColumns := reflection.GetModelColumns(model)
// Build expand joins map (if needed in future)
var expandJoins map[string]string
if len(options.Expand) > 0 {
expandJoins = make(map[string]string)
// TODO: Build actual JOIN SQL for each expand relation
// For now, pass empty map as joins are handled via Preload
// Build expand joins map: custom SQL joins are available in cursor subquery
expandJoins := make(map[string]string)
for _, joinClause := range options.CustomSQLJoin {
alias := extractJoinAlias(joinClause)
if alias != "" {
expandJoins[alias] = joinClause
}
}
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
// Default sort to primary key when none provided
if len(options.Sort) == 0 {
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
}
// Get cursor filter SQL
@@ -2111,7 +2153,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
// Column is already cast to TEXT if needed
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
case "in":
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
if cond == "" {
return query
}
return applyWhere(cond, inArgs...)
case "between":
// Handle between operator - exclusive (> val1 AND < val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
@@ -2187,24 +2233,25 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
// buildFilterCondition builds a single filter condition and returns the condition string and args
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
switch strings.ToLower(filter.Operator) {
case "eq", "equals":
case "eq", "equals", "=":
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
case "neq", "not_equals", "ne":
case "neq", "not_equals", "ne", "!=", "<>":
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
case "gt", "greater_than":
case "gt", "greater_than", ">":
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
case "gte", "greater_than_equals", "ge":
case "gte", "greater_than_equals", "ge", ">=":
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
case "lt", "less_than":
case "lt", "less_than", "<":
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
case "lte", "less_than_equals", "le":
case "lte", "less_than_equals", "le", "<=":
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
case "like":
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
case "ilike":
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
case "in":
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
return cond, inArgs
case "between":
// Handle between operator - exclusive (> val1 AND < val2)
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
@@ -2839,6 +2886,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
// Filter base RequestOptions
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
filtered.RequestOptions.JoinAliases = options.JoinAliases
// Filter SearchColumns
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)

View File

@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
}
}
// Resolve relation names (convert table names to field names) if model is provided
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
if model != nil && !options.XFilesPresent {
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
// disambiguate when multiple fields point to the same related type.
if model != nil {
h.resolveRelationNamesInOptions(&options, model)
}
@@ -550,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
// - "LEFT JOIN departments d ON ..." -> "d"
// - "INNER JOIN users AS u ON ..." -> "u"
// - "JOIN roles r ON ..." -> "r"
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
func extractJoinAlias(joinClause string) string {
// Pattern: JOIN table_name [AS] alias ON ...
// We need to extract the alias (word before ON)
upperJoin := strings.ToUpper(joinClause)
// Find the "JOIN" keyword position
@@ -562,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
return ""
}
// Find the "ON" keyword position
// Lateral joins: alias is the word after the closing ) and before ON
if strings.Contains(upperJoin, "LATERAL") {
lastClose := strings.LastIndex(joinClause, ")")
if lastClose != -1 {
words := strings.Fields(joinClause[lastClose+1:])
// words should be like ["fn", "on", "true"] or ["on", "true"]
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
return words[0]
}
}
return ""
}
// Regular joins: find the "ON" keyword position (first occurrence)
onIdx := strings.Index(upperJoin, " ON ")
if onIdx == -1 {
return ""
@@ -863,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
// Resolve each part of the path
currentModel := model
for _, part := range parts {
resolvedPart := h.resolveRelationName(currentModel, part)
for partIdx, part := range parts {
isLast := partIdx == len(parts)-1
var resolvedPart string
if isLast {
// For the final part, use join-key-aware resolution to disambiguate when
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
localKey := preload.RelatedKey
if localKey == "" {
localKey = preload.ForeignKey
}
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
} else {
resolvedPart = h.resolveRelationName(currentModel, part)
}
resolvedParts = append(resolvedParts, resolvedPart)
// Try to get the model type for the next level
@@ -980,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
return nameOrTable
}
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
if localKey == "" {
return h.resolveRelationName(model, nameOrTable)
}
modelType := reflect.TypeOf(model)
if modelType == nil {
return nameOrTable
}
if modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}
if modelType == nil || modelType.Kind() != reflect.Struct {
return nameOrTable
}
// If it's already a direct field name, return as-is (no ambiguity).
for i := 0; i < modelType.NumField(); i++ {
if modelType.Field(i).Name == nameOrTable {
return nameOrTable
}
}
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
localKeyLower := strings.ToLower(localKey)
// Find all fields whose related type matches nameOrTable, then pick the one
// whose bun join tag local key matches localKey.
var fallbackField string
for i := 0; i < modelType.NumField(); i++ {
field := modelType.Field(i)
fieldType := field.Type
var targetType reflect.Type
if fieldType.Kind() == reflect.Slice {
targetType = fieldType.Elem()
} else if fieldType.Kind() == reflect.Ptr {
targetType = fieldType.Elem()
}
if targetType != nil && targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType == nil || targetType.Kind() != reflect.Struct {
continue
}
normalizedTypeName := strings.ToLower(targetType.Name())
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
if normalizedTypeName != normalizedInput {
continue
}
// Type name matches; record as fallback.
if fallbackField == "" {
fallbackField = field.Name
}
// Check bun join tag: "join:localKey=foreignKey"
bunTag := field.Tag.Get("bun")
for _, tagPart := range strings.Split(bunTag, ",") {
tagPart = strings.TrimSpace(tagPart)
if !strings.HasPrefix(tagPart, "join:") {
continue
}
joinSpec := strings.TrimPrefix(tagPart, "join:")
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
joinCols := strings.Fields(joinSpec)
if len(joinCols) == 0 {
joinCols = []string{joinSpec}
}
for _, joinCol := range joinCols {
eqIdx := strings.Index(joinCol, "=")
if eqIdx < 0 {
continue
}
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
if joinLocalKey == localKeyLower {
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
return field.Name
}
}
}
}
if fallbackField != "" {
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
return fallbackField
}
return h.resolveRelationName(model, nameOrTable)
}
// addXFilesPreload converts an XFiles relation into a PreloadOption
// and recursively processes its children
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
@@ -1061,15 +1182,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
}
}
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Add WHERE clause if SQL conditions specified
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
whereConditions := make([]string, 0)
if len(xfile.SqlAnd) > 0 {
// Process each SQL condition
// Note: We don't add table prefixes here because they're only needed for JOINs
// The handler will add prefixes later if SqlJoins are present
var sqlAndOpts *common.RequestOptions
if len(preloadOpt.JoinAliases) > 0 {
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
}
for _, sqlCond := range xfile.SqlAnd {
// Sanitize the condition without adding prefixes
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
if sanitizedCond != "" {
whereConditions = append(whereConditions, sanitizedCond)
}
@@ -1114,32 +1262,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
}
// Transfer SqlJoins from XFiles to PreloadOption
if len(xfile.SqlJoins) > 0 {
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
for _, joinClause := range xfile.SqlJoins {
// Sanitize the join clause
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
if sanitizedJoin == "" {
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
continue
}
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
// Extract join alias for validation
alias := extractJoinAlias(sanitizedJoin)
if alias != "" {
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
}
}
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
}
// Check if this table has a recursive child - if so, mark THIS preload as recursive
// and store the recursive child's RelatedKey for recursion generation
hasRecursiveChild := false

View File

@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
joinClause: "LEFT JOIN departments",
expected: "",
},
{
name: "LATERAL join with alias",
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
expected: "fn",
},
{
name: "LATERAL join with multiline subquery containing inner ON",
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
expected: "fn",
},
}
for _, tt := range tests {

View File

@@ -12,6 +12,10 @@ import (
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
// Use this for auth checks that need model rules and user context simultaneously.
BeforeHandle HookType = "before_handle"
// Read operation hooks
BeforeRead HookType = "before_read"
AfterRead HookType = "after_read"
@@ -42,6 +46,9 @@ type HookContext struct {
Model interface{}
Options ExtendedRequestOptions
// Operation being dispatched (e.g. "read", "create", "update", "delete")
Operation string
// Operation-specific fields
ID string
Data interface{} // For create/update operations
@@ -56,6 +63,14 @@ type HookContext struct {
// Response writer - allows hooks to modify response
Writer common.ResponseWriter
// Request - the original HTTP request
Request common.Request
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
// Tx provides access to the database/transaction for executing additional SQL
// This allows hooks to run custom queries in addition to the main Query chain
Tx common.Database
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
return fmt.Errorf("hook execution failed: %w", err)
}
// Check if hook requested abort
if ctx.Abort {
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
// logger.Debug("All hooks for %s executed successfully", hookType)

View File

@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
metadataPath := buildRoutePath(schema, entity) + "/metadata"
// Create handler functions for this specific entity
entityHandler := createMuxHandler(handler, schema, entity, "")
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
// Apply authentication middleware if provided
if authMiddleware != nil {
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
entityHandler = authMiddleware(entityHandler)
entityWithIDHandler = authMiddleware(entityWithIDHandler)
metadataHandler = authMiddleware(metadataHandler)
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
}
@@ -289,7 +289,11 @@ func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware Middlewa
return func(w http.ResponseWriter, req bunrouter.Request) error {
// Create an http.Handler that calls the bunrouter handler
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = handler(w, req)
// Replace the embedded *http.Request with the middleware-enriched one
// so that auth context (user ID, etc.) is visible to the handler.
enrichedReq := req
enrichedReq.Request = r
_ = handler(w, enrichedReq)
})
// Wrap with auth middleware and execute

View File

@@ -2,6 +2,7 @@ package restheadspec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
@@ -9,6 +10,17 @@ import (
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)

View File

@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
// Add to blacklist
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
"token": req.Token,
"user_id": req.UserID,
}).Error
// Invalidate session via stored procedure
return nil
}
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
@@ -405,11 +402,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
```
HTTP Request
NewAuthMiddleware → calls provider.Authenticate()
↓ (adds UserContext to context)
NewOptionalAuthMiddleware → calls provider.Authenticate()
↓ (adds UserContext or guest context; never 401)
SetSecurityMiddleware → adds SecurityList to context
Handler.Handle()
Handler.Handle() → resolves model
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
├─ SecurityDisabled → allow
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
└─ UserID == 0 → abort 401
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
@@ -693,15 +695,30 @@ http.Handle("/api/protected", authHandler)
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
http.Handle("/home", optionalHandler)
// Example handler
func myHandler(w http.ResponseWriter, r *http.Request) {
userCtx, _ := security.GetUserContext(r.Context())
if userCtx.UserID == 0 {
// Guest user
} else {
// Authenticated user
}
}
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
apiRouter.Use(security.SetSecurityMiddleware(securityList))
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
```
---
## Model-Level Access Control
```go
// Register model with rules (pkg/modelregistry)
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
SecurityDisabled: false, // skip all auth when true
CanPublicRead: true, // unauthenticated reads allowed
CanPublicCreate: false, // requires auth
CanPublicUpdate: false, // requires auth
CanPublicDelete: false, // requires auth
CanUpdate: true, // authenticated can update
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
})
// CheckModelAuthAllowed used automatically in BeforeHandle hook
// No code needed — call RegisterSecurityHooks and it's applied
```
---

View File

@@ -751,14 +751,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
```
HTTP Request
NewAuthMiddleware (security package)
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
├─ Calls provider.Authenticate(request)
Adds UserContext to context
On success: adds authenticated UserContext to context
└─ On failure: adds guest UserContext (UserID=0) to context
SetSecurityMiddleware (security package)
└─ Adds SecurityList to context
Spec Handler (restheadspec/funcspec/resolvespec)
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
└─ Resolves schema + entity + model from request
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
├─ Adapts spec's HookContext → SecurityContext
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
│ ├─ Loads model rules from context or registry
│ ├─ SecurityDisabled → allow
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
│ └─ UserID == 0 → 401 unauthorized
└─ On error: aborts with 401
BeforeRead Hook (registered by spec)
├─ Adapts spec's HookContext → SecurityContext
@@ -784,7 +795,8 @@ HTTP Response (secured data)
```
**Key Points:**
- Security package is spec-agnostic and provides core logic
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
- Each spec registers its own hooks that adapt to SecurityContext
- Security rules are loaded once and cached for the request
- Row security is applied to the query (database level)
@@ -1002,15 +1014,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
}
```
## Model-Level Access Control
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
```go
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
SecurityDisabled: false, // true = skip all auth checks
CanPublicRead: true, // unauthenticated GET allowed
CanPublicCreate: false, // requires auth
CanPublicUpdate: false, // requires auth
CanPublicDelete: false, // requires auth
CanUpdate: true, // authenticated users can update
CanDelete: false, // authenticated users cannot delete
})
```
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
1. `SecurityDisabled` → allow all
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
3. Guest (UserID == 0) → return 401
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
---
## Middleware and Handler API
### NewAuthMiddleware
Standard middleware that authenticates all requests:
Standard middleware that authenticates all requests and returns 401 on failure:
```go
router.Use(security.NewAuthMiddleware(securityList))
```
### NewOptionalAuthMiddleware
Middleware for spec routes — always continues; sets guest context on auth failure:
```go
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
apiRouter.Use(security.SetSecurityMiddleware(securityList))
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
```
Routes can skip authentication using the `SkipAuth` helper:
```go

View File

@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
}
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
// For JWT, logout could involve token blacklisting
// Add token to blacklist table
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
// "token": req.Token,
// "expires_at": time.Now().Add(24 * time.Hour),
// }).Error
return nil
}

View File

@@ -275,6 +275,64 @@ func checkModelDeleteAllowed(secCtx SecurityContext) error {
return nil
}
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
// model rules and the current user's authentication state. It is intended for use in
// a BeforeHandle hook, fired after model resolution.
//
// Logic:
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
// 2. SecurityDisabled → allow.
// 3. operation == "read" && CanPublicRead → allow.
// 4. operation == "create" && CanPublicCreate → allow.
// 5. operation == "update" && CanPublicUpdate → allow.
// 6. operation == "delete" && CanPublicDelete → allow.
// 7. Guest (UserID == 0) → return "authentication required".
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
if !ok {
schema := secCtx.GetSchema()
entity := secCtx.GetEntity()
var err error
if schema != "" {
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
}
if err != nil || schema == "" {
rules, err = modelregistry.GetModelRulesByName(entity)
}
if err != nil {
// Model not registered - fall through to auth check
userID, _ := secCtx.GetUserID()
if userID == 0 {
return fmt.Errorf("authentication required")
}
return nil
}
}
if rules.SecurityDisabled {
return nil
}
if operation == "read" && rules.CanPublicRead {
return nil
}
if operation == "create" && rules.CanPublicCreate {
return nil
}
if operation == "update" && rules.CanPublicUpdate {
return nil
}
if operation == "delete" && rules.CanPublicDelete {
return nil
}
userID, _ := secCtx.GetUserID()
if userID == 0 {
return fmt.Errorf("authentication required")
}
return nil
}
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
return checkModelUpdateAllowed(secCtx)

View File

@@ -139,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
})
}
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
// On auth failure, a guest user context is set instead of returning 401.
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
// after model resolution.
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
provider := securityList.Provider()
if provider == nil {
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
return
}
userCtx, err := provider.Authenticate(r)
if err != nil {
guestCtx := createGuestContext(r)
next.ServeHTTP(w, setUserContext(r, guestCtx))
return
}
next.ServeHTTP(w, setUserContext(r, userCtx))
})
}
}
// NewAuthMiddleware creates an authentication middleware with the given security list
// This middleware extracts user authentication from the request and adds it to context
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
@@ -431,6 +456,125 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
return meta, ok
}
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
// All fields are optional; sensible secure defaults are applied when omitted.
type SessionCookieOptions struct {
// Name is the cookie name. Defaults to "session_token".
Name string
// Path is the cookie path. Defaults to "/".
Path string
// Domain restricts the cookie to a specific domain. Empty means current host.
Domain string
// Secure sets the Secure flag. Defaults to true.
// Set to false only in local development over HTTP.
Secure *bool
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
SameSite http.SameSite
}
func (o SessionCookieOptions) name() string {
if o.Name != "" {
return o.Name
}
return "session_token"
}
func (o SessionCookieOptions) path() string {
if o.Path != "" {
return o.Path
}
return "/"
}
func (o SessionCookieOptions) secure() bool {
if o.Secure != nil {
return *o.Secure
}
return true
}
func (o SessionCookieOptions) sameSite() http.SameSite {
if o.SameSite != 0 {
return o.SameSite
}
return http.SameSiteLaxMode
}
// SetSessionCookie writes the session_token cookie to the response after a successful login.
// Call this immediately after a successful Authenticator.Login() call.
//
// Example:
//
// resp, err := auth.Login(r.Context(), req)
// if err != nil { ... }
// security.SetSessionCookie(w, resp)
// json.NewEncoder(w).Encode(resp)
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
maxAge := 0
if loginResp.ExpiresIn > 0 {
maxAge = int(loginResp.ExpiresIn)
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: loginResp.Token,
Path: o.path(),
Domain: o.Domain,
MaxAge: maxAge,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
//
// Example:
//
// token := security.GetSessionCookie(r)
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
cookie, err := r.Cookie(o.name())
if err != nil {
return ""
}
return cookie.Value
}
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
// Call this after a successful Authenticator.Logout() call.
//
// Example:
//
// err := auth.Logout(r.Context(), req)
// if err != nil { ... }
// security.ClearSessionCookie(w)
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
var o SessionCookieOptions
if len(opts) > 0 {
o = opts[0]
}
http.SetCookie(w, &http.Cookie{
Name: o.name(),
Value: "",
Path: o.path(),
Domain: o.Domain,
MaxAge: -1,
HttpOnly: true,
Secure: o.secure(),
SameSite: o.sameSite(),
})
}
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)

View File

@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
var errMsg *string
var userID *int
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error, p_user_id
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
`, userJSON).Scan(&success, &errMsg, &userID)
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_user_id
FROM %s($1::jsonb)
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
if err != nil {
return 0, fmt.Errorf("failed to get or create user: %w", err)
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
var success bool
var errMsg *string
err = a.db.QueryRowContext(ctx, `
SELECT p_success, p_error
FROM resolvespec_oauth_createsession($1::jsonb)
`, sessionJSON).Scan(&success, &errMsg)
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM %s($1::jsonb)
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
if err != nil {
return fmt.Errorf("failed to create session: %w", err)
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var errMsg *string
var sessionData []byte
err = a.db.QueryRowContext(ctx, `
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getrefreshtoken($1)
`, refreshToken).Scan(&success, &errMsg, &sessionData)
FROM %s($1)
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
if err != nil {
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var updateSuccess bool
var updateErrMsg *string
err = a.db.QueryRowContext(ctx, `
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
FROM %s($1::jsonb)
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
if err != nil {
return nil, fmt.Errorf("failed to update session: %w", err)
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
var userErrMsg *string
var userData []byte
err = a.db.QueryRowContext(ctx, `
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
SELECT p_success, p_error, p_data::text
FROM resolvespec_oauth_getuser($1)
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
FROM %s($1)
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)

View File

@@ -11,12 +11,14 @@ import (
)
// DatabasePasskeyProvider implements PasskeyProvider using database storage
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabasePasskeyProvider struct {
db *sql.DB
rpID string // Relying Party ID (domain)
rpName string // Relying Party display name
rpOrigin string // Expected origin for WebAuthn
timeout int64 // Timeout in milliseconds (default: 60000)
sqlNames *SQLNames
}
// DatabasePasskeyProviderOptions configures the passkey provider
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
RPOrigin string
// Timeout is the timeout for operations in milliseconds (default: 60000)
Timeout int64
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
SQLNames *SQLNames
}
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
opts.Timeout = 60000 // 60 seconds default
}
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabasePasskeyProvider{
db: db,
rpID: opts.RPID,
rpName: opts.RPName,
rpOrigin: opts.RPOrigin,
timeout: opts.Timeout,
sqlNames: sqlNames,
}
}
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
var errorMsg sql.NullString
var credentialID sql.NullInt64
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
if err != nil {
return nil, fmt.Errorf("failed to store credential: %w", err)
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
var userID sql.NullInt64
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var errorMsg sql.NullString
var credentialJSON sql.NullString
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
if err != nil {
return 0, fmt.Errorf("failed to get credential: %w", err)
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
var updateError sql.NullString
var cloneWarning sql.NullBool
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
if err != nil {
return 0, fmt.Errorf("failed to update counter: %w", err)
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
var errorMsg sql.NullString
var credentialsJSON sql.NullString
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
if err != nil {
return nil, fmt.Errorf("failed to get credentials: %w", err)
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to delete credential: %w", err)
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("failed to update credential name: %w", err)

View File

@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
// DatabaseAuthenticator provides session-based authentication with database storage
// All database operations go through stored procedures for security and consistency
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
// resolvespec_session_update, resolvespec_refresh_token
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// See database_schema.sql for procedure definitions
// Also supports multiple OAuth2 providers configured with WithOAuth2()
// Also supports passkey authentication configured with WithPasskey()
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
db *sql.DB
cache *cache.Cache
cacheTTL time.Duration
sqlNames *SQLNames
// OAuth2 providers registry (multiple providers supported)
oauth2Providers map[string]*OAuth2Provider
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
Cache *cache.Cache
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
PasskeyProvider PasskeyProvider
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
// Partial overrides are supported: only set the fields you want to change.
SQLNames *SQLNames
}
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
cacheInstance = cache.GetDefaultCache()
}
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
return &DatabaseAuthenticator{
db: db,
cache: cacheInstance,
cacheTTL: opts.CacheTTL,
sqlNames: sqlNames,
passkeyProvider: opts.PasskeyProvider,
}
}
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
return nil, fmt.Errorf("failed to marshal login request: %w", err)
}
// Call resolvespec_login stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
return nil, fmt.Errorf("failed to marshal register request: %w", err)
}
// Call resolvespec_register stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("register query failed: %w", err)
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
return fmt.Errorf("failed to marshal logout request: %w", err)
}
// Call resolvespec_logout stored procedure
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
@@ -222,9 +225,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
if sessionToken == "" {
// Try cookie
cookie, err := r.Cookie("session_token")
if err == nil {
tokens = []string{cookie.Value}
if token := GetSessionCookie(r); token != "" {
tokens = []string{token}
reference = "cookie"
}
} else {
@@ -267,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
var errorMsg sql.NullString
var userJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("session query failed: %w", err)
@@ -339,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
return
}
// Call resolvespec_session_update stored procedure
var success bool
var errorMsg sql.NullString
var updatedUserJSON sql.NullString
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
}
// RefreshToken implements Refreshable interface
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
// Call api_refresh_token stored procedure
// First, we need to get the current user context for the refresh token
var success bool
var errorMsg sql.NullString
var userJSON sql.NullString
// Get current session to pass to refresh
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("refresh token query failed: %w", err)
@@ -369,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
return nil, fmt.Errorf("invalid refresh token")
}
// Call resolvespec_refresh_token to generate new token
var newSuccess bool
var newErrorMsg sql.NullString
var newUserJSON sql.NullString
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
if err != nil {
return nil, fmt.Errorf("refresh token generation failed: %w", err)
@@ -402,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
// JWTAuthenticator provides JWT token-based authentication
// All database operations go through stored procedures
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
type JWTAuthenticator struct {
secretKey []byte
db *sql.DB
sqlNames *SQLNames
}
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
return &JWTAuthenticator{
secretKey: []byte(secretKey),
db: db,
sqlNames: resolveSQLNames(names...),
}
}
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
// Call resolvespec_jwt_login stored procedure
var success bool
var errorMsg sql.NullString
var userJSON []byte
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
if err != nil {
return nil, fmt.Errorf("login query failed: %w", err)
@@ -472,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
}
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
// Call resolvespec_jwt_logout stored procedure
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("logout query failed: %w", err)
@@ -512,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
// DatabaseColumnSecurityProvider loads column security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_column_security
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseColumnSecurityProvider struct {
db *sql.DB
db *sql.DB
sqlNames *SQLNames
}
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db}
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
var rules []ColumnSecurity
// Call resolvespec_column_security stored procedure
var success bool
var errorMsg sql.NullString
var rulesJSON []byte
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
if err != nil {
return nil, fmt.Errorf("failed to load column security: %w", err)
@@ -577,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
// DatabaseRowSecurityProvider loads row security from database
// All database operations go through stored procedures
// Requires stored procedure: resolvespec_row_security
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
type DatabaseRowSecurityProvider struct {
db *sql.DB
db *sql.DB
sqlNames *SQLNames
}
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db}
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
}
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
var template string
var hasBlock bool
// Call resolvespec_row_security stored procedure
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
if err != nil {
@@ -759,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
return nil, fmt.Errorf("passkey authentication failed: %w", err)
}
// Get user data from database
var username, email, roles string
var userLevel int
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
if err != nil {
return nil, fmt.Errorf("failed to get user data: %w", err)
// Build request JSON for passkey login stored procedure
reqData := map[string]any{
"user_id": userID,
}
// Generate session token
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
expiresAt := time.Now().Add(24 * time.Hour)
// Extract IP and user agent from claims
ipAddress := ""
userAgent := ""
if req.Claims != nil {
if ip, ok := req.Claims["ip_address"].(string); ok {
ipAddress = ip
reqData["ip_address"] = ip
}
if ua, ok := req.Claims["user_agent"].(string); ok {
userAgent = ua
reqData["user_agent"] = ua
}
}
// Create session
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
VALUES ($1, $2, $3, $4, $5, now())`
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
reqJSON, err := json.Marshal(reqData)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
}
// Update last login
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
var success bool
var errorMsg sql.NullString
var dataJSON sql.NullString
// Return login response
return &LoginResponse{
Token: sessionToken,
User: &UserContext{
UserID: userID,
UserName: username,
Email: email,
UserLevel: userLevel,
SessionID: sessionToken,
Roles: parseRoles(roles),
},
ExpiresIn: int64(24 * time.Hour.Seconds()),
}, nil
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
if err != nil {
return nil, fmt.Errorf("passkey login query failed: %w", err)
}
if !success {
if errorMsg.Valid {
return nil, fmt.Errorf("%s", errorMsg.String)
}
return nil, fmt.Errorf("passkey login failed")
}
var response LoginResponse
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
}
return &response, nil
}
// GetPasskeyCredentials returns all passkey credentials for a user

222
pkg/security/sql_names.go Normal file
View File

@@ -0,0 +1,222 @@
package security
import (
"fmt"
"reflect"
"regexp"
)
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
// SQLNames defines all configurable SQL stored procedure and table names
// used by the security package. Override individual fields to remap
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
// and MergeSQLNames() to apply partial overrides.
type SQLNames struct {
// Auth procedures (DatabaseAuthenticator)
Login string // default: "resolvespec_login"
Register string // default: "resolvespec_register"
Logout string // default: "resolvespec_logout"
Session string // default: "resolvespec_session"
SessionUpdate string // default: "resolvespec_session_update"
RefreshToken string // default: "resolvespec_refresh_token"
// JWT procedures (JWTAuthenticator)
JWTLogin string // default: "resolvespec_jwt_login"
JWTLogout string // default: "resolvespec_jwt_logout"
// Security policy procedures
ColumnSecurity string // default: "resolvespec_column_security"
RowSecurity string // default: "resolvespec_row_security"
// TOTP procedures (DatabaseTwoFactorProvider)
TOTPEnable string // default: "resolvespec_totp_enable"
TOTPDisable string // default: "resolvespec_totp_disable"
TOTPGetStatus string // default: "resolvespec_totp_get_status"
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
// Passkey procedures (DatabasePasskeyProvider)
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
PasskeyLogin string // default: "resolvespec_passkey_login"
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
OAuthGetUser string // default: "resolvespec_oauth_getuser"
}
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
func DefaultSQLNames() *SQLNames {
return &SQLNames{
Login: "resolvespec_login",
Register: "resolvespec_register",
Logout: "resolvespec_logout",
Session: "resolvespec_session",
SessionUpdate: "resolvespec_session_update",
RefreshToken: "resolvespec_refresh_token",
JWTLogin: "resolvespec_jwt_login",
JWTLogout: "resolvespec_jwt_logout",
ColumnSecurity: "resolvespec_column_security",
RowSecurity: "resolvespec_row_security",
TOTPEnable: "resolvespec_totp_enable",
TOTPDisable: "resolvespec_totp_disable",
TOTPGetStatus: "resolvespec_totp_get_status",
TOTPGetSecret: "resolvespec_totp_get_secret",
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
PasskeyGetCredential: "resolvespec_passkey_get_credential",
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
PasskeyUpdateName: "resolvespec_passkey_update_name",
PasskeyLogin: "resolvespec_passkey_login",
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
OAuthCreateSession: "resolvespec_oauth_createsession",
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
OAuthGetUser: "resolvespec_oauth_getuser",
}
}
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
// If override is nil, a copy of base is returned.
func MergeSQLNames(base, override *SQLNames) *SQLNames {
if override == nil {
copied := *base
return &copied
}
merged := *base
if override.Login != "" {
merged.Login = override.Login
}
if override.Register != "" {
merged.Register = override.Register
}
if override.Logout != "" {
merged.Logout = override.Logout
}
if override.Session != "" {
merged.Session = override.Session
}
if override.SessionUpdate != "" {
merged.SessionUpdate = override.SessionUpdate
}
if override.RefreshToken != "" {
merged.RefreshToken = override.RefreshToken
}
if override.JWTLogin != "" {
merged.JWTLogin = override.JWTLogin
}
if override.JWTLogout != "" {
merged.JWTLogout = override.JWTLogout
}
if override.ColumnSecurity != "" {
merged.ColumnSecurity = override.ColumnSecurity
}
if override.RowSecurity != "" {
merged.RowSecurity = override.RowSecurity
}
if override.TOTPEnable != "" {
merged.TOTPEnable = override.TOTPEnable
}
if override.TOTPDisable != "" {
merged.TOTPDisable = override.TOTPDisable
}
if override.TOTPGetStatus != "" {
merged.TOTPGetStatus = override.TOTPGetStatus
}
if override.TOTPGetSecret != "" {
merged.TOTPGetSecret = override.TOTPGetSecret
}
if override.TOTPRegenerateBackup != "" {
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
}
if override.TOTPValidateBackupCode != "" {
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
}
if override.PasskeyStoreCredential != "" {
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
}
if override.PasskeyGetCredsByUsername != "" {
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
}
if override.PasskeyGetCredential != "" {
merged.PasskeyGetCredential = override.PasskeyGetCredential
}
if override.PasskeyUpdateCounter != "" {
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
}
if override.PasskeyGetUserCredentials != "" {
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
}
if override.PasskeyDeleteCredential != "" {
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
}
if override.PasskeyUpdateName != "" {
merged.PasskeyUpdateName = override.PasskeyUpdateName
}
if override.PasskeyLogin != "" {
merged.PasskeyLogin = override.PasskeyLogin
}
if override.OAuthGetOrCreateUser != "" {
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
}
if override.OAuthCreateSession != "" {
merged.OAuthCreateSession = override.OAuthCreateSession
}
if override.OAuthGetRefreshToken != "" {
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
}
if override.OAuthUpdateRefreshToken != "" {
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
}
if override.OAuthGetUser != "" {
merged.OAuthGetUser = override.OAuthGetUser
}
return &merged
}
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
// Returns an error if any field contains invalid characters.
func ValidateSQLNames(names *SQLNames) error {
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
val := field.String()
if val != "" && !validSQLIdentifier.MatchString(val) {
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
}
}
return nil
}
// resolveSQLNames merges an optional override with defaults.
// Used by constructors that accept variadic *SQLNames.
func resolveSQLNames(override ...*SQLNames) *SQLNames {
if len(override) > 0 && override[0] != nil {
return MergeSQLNames(DefaultSQLNames(), override[0])
}
return DefaultSQLNames()
}

View File

@@ -0,0 +1,145 @@
package security
import (
"reflect"
"testing"
)
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
names := DefaultSQLNames()
v := reflect.ValueOf(names).Elem()
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.Kind() != reflect.String {
continue
}
if field.String() == "" {
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
}
}
}
func TestMergeSQLNames_PartialOverride(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{
Login: "custom_login",
TOTPEnable: "custom_totp_enable",
PasskeyLogin: "custom_passkey_login",
}
merged := MergeSQLNames(base, override)
if merged.Login != "custom_login" {
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
}
if merged.TOTPEnable != "custom_totp_enable" {
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
}
if merged.PasskeyLogin != "custom_passkey_login" {
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
}
// Non-overridden fields should retain defaults
if merged.Logout != "resolvespec_logout" {
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
}
if merged.Session != "resolvespec_session" {
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
}
}
func TestMergeSQLNames_NilOverride(t *testing.T) {
base := DefaultSQLNames()
merged := MergeSQLNames(base, nil)
// Should be a copy, not the same pointer
if merged == base {
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
}
// All values should match
v1 := reflect.ValueOf(base).Elem()
v2 := reflect.ValueOf(merged).Elem()
typ := v1.Type()
for i := 0; i < v1.NumField(); i++ {
f1 := v1.Field(i)
f2 := v2.Field(i)
if f1.Kind() != reflect.String {
continue
}
if f1.String() != f2.String() {
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
}
}
}
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
base := DefaultSQLNames()
originalLogin := base.Login
override := &SQLNames{Login: "custom_login"}
_ = MergeSQLNames(base, override)
if base.Login != originalLogin {
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
}
}
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
base := DefaultSQLNames()
override := &SQLNames{}
v := reflect.ValueOf(override).Elem()
for i := 0; i < v.NumField(); i++ {
if v.Field(i).Kind() == reflect.String {
v.Field(i).SetString("custom_sentinel")
}
}
merged := MergeSQLNames(base, override)
mv := reflect.ValueOf(merged).Elem()
typ := mv.Type()
for i := 0; i < mv.NumField(); i++ {
if mv.Field(i).Kind() != reflect.String {
continue
}
if mv.Field(i).String() != "custom_sentinel" {
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
}
}
}
func TestValidateSQLNames_Valid(t *testing.T) {
names := DefaultSQLNames()
if err := ValidateSQLNames(names); err != nil {
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
}
}
func TestValidateSQLNames_Invalid(t *testing.T) {
names := DefaultSQLNames()
names.Login = "resolvespec_login; DROP TABLE users; --"
err := ValidateSQLNames(names)
if err == nil {
t.Error("ValidateSQLNames should reject names with invalid characters")
}
}
func TestResolveSQLNames_NoOverride(t *testing.T) {
names := resolveSQLNames()
if names.Login != "resolvespec_login" {
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
}
}
func TestResolveSQLNames_WithOverride(t *testing.T) {
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
if names.Login != "custom_login" {
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
}
if names.Logout != "resolvespec_logout" {
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
}
}

View File

@@ -9,23 +9,23 @@ import (
)
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
// See totp_database_schema.sql for procedure definitions
type DatabaseTwoFactorProvider struct {
db *sql.DB
totpGen *TOTPGenerator
db *sql.DB
totpGen *TOTPGenerator
sqlNames *SQLNames
}
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
if config == nil {
config = DefaultTwoFactorConfig()
}
return &DatabaseTwoFactorProvider{
db: db,
totpGen: NewTOTPGenerator(config),
db: db,
totpGen: NewTOTPGenerator(config),
sqlNames: resolveSQLNames(names...),
}
}
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("enable 2FA query failed: %w", err)
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
if err != nil {
return fmt.Errorf("disable 2FA query failed: %w", err)
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
var errorMsg sql.NullString
var enabled bool
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
if err != nil {
return false, fmt.Errorf("get 2FA status query failed: %w", err)
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
var errorMsg sql.NullString
var secret sql.NullString
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
if err != nil {
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
var success bool
var errorMsg sql.NullString
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
if err != nil {
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
var errorMsg sql.NullString
var valid bool
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
if err != nil {
return false, fmt.Errorf("validate backup code query failed: %w", err)

View File

@@ -98,6 +98,7 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -105,12 +106,25 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
// First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
return p.fs.Open(actualPath)
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases any resources held by the provider.

View File

@@ -53,6 +53,7 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -60,12 +61,26 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
return p.fs.Open(actualPath)
// First try the actual path with prefix
if file, err := p.fs.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.fs.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases any resources held by the provider.

View File

@@ -56,6 +56,7 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
// Apply prefix stripping by prepending the prefix to the requested path
actualPath := name
alternatePath := ""
if p.stripPrefix != "" {
// Clean the paths to handle leading/trailing slashes
prefix := strings.Trim(p.stripPrefix, "/")
@@ -63,12 +64,26 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
if prefix != "" {
actualPath = path.Join(prefix, cleanName)
alternatePath = cleanName
} else {
actualPath = cleanName
}
}
return p.zipFS.Open(actualPath)
// First try the actual path with prefix
if file, err := p.zipFS.Open(actualPath); err == nil {
return file, nil
}
// If alternate path is different, try it as well
if alternatePath != "" && alternatePath != actualPath {
if file, err := p.zipFS.Open(alternatePath); err == nil {
return file, nil
}
}
// If both attempts fail, return the error from the first attempt
return nil, fmt.Errorf("file not found: %s", name)
}
// Close releases resources held by the zip reader.

View File

@@ -330,6 +330,7 @@ Hooks allow you to intercept and modify operations at various points in the life
### Available Hook Types
- **BeforeHandle** — fires after model resolution, before operation dispatch (auth checks)
- **BeforeRead** / **AfterRead**
- **BeforeCreate** / **AfterCreate**
- **BeforeUpdate** / **AfterUpdate**
@@ -337,6 +338,8 @@ Hooks allow you to intercept and modify operations at various points in the life
- **BeforeSubscribe** / **AfterSubscribe**
- **BeforeConnect** / **AfterConnect**
`HookContext` includes `Operation string` (`"read"`, `"create"`, `"update"`, `"delete"`) and `Abort bool`, `AbortMessage string`, `AbortCode int` for abort signaling.
### Hook Example
```go
@@ -599,7 +602,19 @@ asyncio.run(main())
## Authentication
Implement authentication using hooks:
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
```go
import "github.com/bitechdev/ResolveSpec/pkg/security"
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
securityList := security.NewSecurityList(provider)
websocketspec.RegisterSecurityHooks(handler, securityList)
// Registers BeforeHandle (model auth), BeforeRead (load rules),
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
```
Or implement custom authentication using hooks directly:
```go
handler := websocketspec.NewHandlerWithGORM(db)

View File

@@ -177,6 +177,16 @@ func (h *Handler) handleRequest(conn *Connection, msg *Message) {
Metadata: make(map[string]interface{}),
}
// Execute BeforeHandle hook - auth check fires here, after model resolution
hookCtx.Operation = string(msg.Operation)
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
if hookCtx.Abort {
errResp := NewErrorResponse(msg.ID, "unauthorized", hookCtx.AbortMessage)
_ = conn.SendJSON(errResp)
}
return
}
// Route to operation handler
switch msg.Operation {
case OperationRead:
@@ -618,7 +628,10 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
if hookCtx.Options != nil {
for _, filter := range hookCtx.Options.Filters {
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
cond, args := h.buildFilterCondition(filter)
if cond != "" {
countQuery = countQuery.Where(cond, args...)
}
}
}
count, _ := countQuery.Count(hookCtx.Context)
@@ -790,14 +803,12 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
// buildFilterCondition builds a filter condition and returns it with args
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
var condition string
var args []interface{}
if strings.EqualFold(filter.Operator, "in") {
cond, args := common.BuildInCondition(filter.Column, filter.Value)
return cond, args
}
operatorSQL := h.getOperatorSQL(filter.Operator)
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
args = []interface{}{filter.Value}
return condition, args
return fmt.Sprintf("%s %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
}
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists

View File

@@ -2,6 +2,7 @@ package websocketspec
import (
"context"
"fmt"
"github.com/bitechdev/ResolveSpec/pkg/common"
)
@@ -10,6 +11,10 @@ import (
type HookType string
const (
// BeforeHandle fires after model resolution, before operation dispatch.
// Use this for auth checks that need model rules and user context simultaneously.
BeforeHandle HookType = "before_handle"
// BeforeRead is called before a read operation
BeforeRead HookType = "before_read"
// AfterRead is called after a read operation
@@ -83,6 +88,9 @@ type HookContext struct {
// Options contains the parsed request options
Options *common.RequestOptions
// Operation being dispatched (e.g. "read", "create", "update", "delete")
Operation string
// ID is the record ID for single-record operations
ID string
@@ -98,6 +106,11 @@ type HookContext struct {
// Error is any error that occurred (for after hooks)
Error error
// Allow hooks to abort the operation
Abort bool // If set to true, the operation will be aborted
AbortMessage string // Message to return if aborted
AbortCode int // HTTP status code if aborted
// Metadata is additional context data
Metadata map[string]interface{}
}
@@ -171,6 +184,11 @@ func (hr *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
if err := hook(ctx); err != nil {
return err
}
// Check if hook requested abort
if ctx.Abort {
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
}
}
return nil

View File

@@ -2,6 +2,7 @@ package websocketspec
import (
"context"
"net/http"
"github.com/bitechdev/ResolveSpec/pkg/logger"
"github.com/bitechdev/ResolveSpec/pkg/security"
@@ -9,6 +10,17 @@ import (
// RegisterSecurityHooks registers all security-related hooks with the handler
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
// Hook 0: BeforeHandle - enforce auth after model resolution
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
hookCtx.Abort = true
hookCtx.AbortMessage = err.Error()
hookCtx.AbortCode = http.StatusUnauthorized
return err
}
return nil
})
// Hook 1: BeforeRead - Load security rules
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
secCtx := newSecurityContext(hookCtx)