mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
37 Commits
v1.0.52
...
feature-au
| Author | SHA1 | Date | |
|---|---|---|---|
| 6502b55797 | |||
| aa095d6bfd | |||
| ea5bb38ee4 | |||
| c2e2c9b873 | |||
| 4adf94fe37 | |||
|
|
405a04a192 | ||
|
|
c1b16d363a | ||
|
|
568df8c6d6 | ||
|
|
aa362c77da | ||
|
|
1641eaf278 | ||
|
|
200a03c225 | ||
|
|
7ef9cf39d3 | ||
|
|
7f6410f665 | ||
|
|
835bbb0727 | ||
|
|
047a1cc187 | ||
|
|
7a498edab7 | ||
|
|
f10bb0827e | ||
|
|
22a4ab345a | ||
|
|
e289c2ed8f | ||
|
|
0d50bcfee6 | ||
| 4df626ea71 | |||
|
|
7dd630dec2 | ||
|
|
613bf22cbd | ||
| d1ae4fe64e | |||
| 254102bfac | |||
| 6c27419dbc | |||
| 377336caf4 | |||
| 79720d5421 | |||
| e7ab0a20d6 | |||
| e4087104a9 | |||
|
|
17e580a9d3 | ||
|
|
337a007d57 | ||
|
|
e923b0a2a3 | ||
| ea4a4371ba | |||
| b3694e50fe | |||
| b76dae5991 | |||
| dc85008d7f |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -26,4 +26,6 @@ go.work.sum
|
||||
bin/
|
||||
test.db
|
||||
/testserver
|
||||
tests/data/
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
|
||||
27
LICENSE
27
LICENSE
@@ -1,3 +1,18 @@
|
||||
Project Notice
|
||||
|
||||
This project was independently developed.
|
||||
|
||||
The contents of this repository were prepared and published outside any time
|
||||
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
|
||||
or rely upon any proprietary or confidential information, trade secrets,
|
||||
protected designs, or other intellectual property of Bitech Systems CC.
|
||||
|
||||
No portion of this repository reproduces any Bitech Systems CC-specific
|
||||
implementation, design asset, confidential workflow, or non-public technical material.
|
||||
|
||||
This notice is provided for clarification only and does not modify the terms of
|
||||
the Apache License, Version 2.0.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
@@ -32,15 +47,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
@@ -56,7 +71,7 @@ END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||
|
||||
Copyright 2025 wdevs
|
||||
|
||||
|
||||
82
README.md
82
README.md
@@ -9,6 +9,7 @@ ResolveSpec is a flexible and powerful REST API specification and implementation
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||
|
||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||
@@ -21,6 +22,7 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* [Quick Start](#quick-start)
|
||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||
* [Architecture](#architecture)
|
||||
* [API Structure](#api-structure)
|
||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||
@@ -50,6 +52,15 @@ All share the same core architecture and provide dynamic data querying, relation
|
||||
* **🆕 Backward Compatible**: Existing code works without changes
|
||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||
|
||||
### ResolveMCP (v3.2+)
|
||||
|
||||
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||
|
||||
### RestHeadSpec (v2.1+)
|
||||
|
||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||
@@ -190,6 +201,40 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
### ResolveMCP (MCP Server)
|
||||
|
||||
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
|
||||
// Create handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models — must be done BEFORE Build()
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "posts", &Post{})
|
||||
|
||||
// Finalize: registers MCP tools and resources
|
||||
handler.Build()
|
||||
|
||||
// Mount SSE transport on your existing router
|
||||
router := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||
|
||||
// MCP clients connect to:
|
||||
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||
// Messages: POST http://localhost:8080/mcp/message
|
||||
//
|
||||
// Auto-registered tools per model:
|
||||
// read_public_users — filter, sort, paginate, preload
|
||||
// create_public_users — insert a new record
|
||||
// update_public_users — update a record by ID
|
||||
// delete_public_users — delete a record by ID
|
||||
```
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two Complementary APIs
|
||||
@@ -344,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
#### ResolveMCP - MCP Server
|
||||
|
||||
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||
|
||||
**Key Features**:
|
||||
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||
- Same Before/After lifecycle hooks as ResolveSpec
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||
|
||||
#### FuncSpec - Function-Based SQL API
|
||||
|
||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||
@@ -357,6 +415,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa
|
||||
|
||||
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
|
||||
|
||||
#### ResolveSpec JS - TypeScript Client Library
|
||||
|
||||
TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.
|
||||
|
||||
**Clients**:
|
||||
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
|
||||
- Header-based REST client (`HeaderSpecClient`)
|
||||
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect
|
||||
|
||||
For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).
|
||||
|
||||
### Real-Time Communication
|
||||
|
||||
#### WebSocketSpec - WebSocket API
|
||||
@@ -518,7 +587,18 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
## What's New
|
||||
|
||||
### v3.1 (Latest - February 2026)
|
||||
### v3.2 (Latest - March 2026)
|
||||
|
||||
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||
|
||||
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||
|
||||
### v3.1 (February 2026)
|
||||
|
||||
**SQLite Schema Translation (🆕)**:
|
||||
|
||||
|
||||
5
go.mod
5
go.mod
@@ -15,6 +15,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/microsoft/go-mssqldb v1.9.5
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
@@ -40,6 +41,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
@@ -78,6 +80,7 @@ require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -131,6 +134,7 @@ require (
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
@@ -143,7 +147,6 @@ require (
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
||||
golang.org/x/mod v0.31.0 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@@ -120,6 +120,8 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -173,6 +175,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
@@ -326,6 +330,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
|
||||
@@ -394,12 +394,12 @@ func (p *PgSQLSelectQuery) buildSQL() string {
|
||||
|
||||
// LIMIT clause
|
||||
if p.limit > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
|
||||
fmt.Fprintf(&sb, " LIMIT %d", p.limit)
|
||||
}
|
||||
|
||||
// OFFSET clause
|
||||
if p.offset > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
|
||||
fmt.Fprintf(&sb, " OFFSET %d", p.offset)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
|
||||
@@ -2,6 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -167,16 +168,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
// Keys are stored lowercase for case-insensitive matching
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
allowedPrefixes[strings.ToLower(tableName)] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
@@ -184,7 +186,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
// Add join aliases as allowed prefixes
|
||||
for _, alias := range options[0].JoinAliases {
|
||||
if alias != "" {
|
||||
allowedPrefixes[alias] = true
|
||||
allowedPrefixes[strings.ToLower(alias)] = true
|
||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||
}
|
||||
}
|
||||
@@ -216,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
|
||||
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
@@ -925,3 +927,36 @@ func extractLeftSideOfComparison(cond string) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
|
||||
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
|
||||
// Returns a single-element slice if the value is not a slice type.
|
||||
func FilterValueToSlice(v interface{}) []interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Slice {
|
||||
result := make([]interface{}, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
result[i] = rv.Index(i).Interface()
|
||||
}
|
||||
return result
|
||||
}
|
||||
return []interface{}{v}
|
||||
}
|
||||
|
||||
// BuildInCondition builds a parameterized IN condition from a filter value.
|
||||
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
|
||||
// Returns ("", nil) if the value is empty or not a slice.
|
||||
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
|
||||
values := FilterValueToSlice(v)
|
||||
if len(values) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
placeholders := make([]string, len(values))
|
||||
for i := range values {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ type Connection interface {
|
||||
Bun() (*bun.DB, error)
|
||||
GORM() (*gorm.DB, error)
|
||||
Native() (*sql.DB, error)
|
||||
DB() (*sql.DB, error)
|
||||
|
||||
// Common Database interface (for SQL databases)
|
||||
Database() (common.Database, error)
|
||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB connection
|
||||
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
// Bun returns a Bun ORM instance wrapping the native connection
|
||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||
if c == nil {
|
||||
@@ -645,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// DB returns an error for MongoDB connections
|
||||
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Database returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Database() (common.Database, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
|
||||
@@ -3,6 +3,7 @@ package providers_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to a channel with a handler
|
||||
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
|
||||
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen: %v", err))
|
||||
log.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to notify: %v", err))
|
||||
log.Fatalf("Failed to notify: %v", err)
|
||||
}
|
||||
|
||||
// Wait for notification to be processed
|
||||
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
|
||||
|
||||
// Unsubscribe from the channel
|
||||
if err := listener.Unlisten("user_events"); err != nil {
|
||||
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
||||
log.Fatalf("Failed to unlisten: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Listen to multiple channels
|
||||
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
fmt.Printf("[%s] %s\n", ch, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
||||
log.Fatalf("Failed to listen on %s: %v", channel, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Subscribe to application events
|
||||
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// The listener automatically reconnects if the connection is lost
|
||||
|
||||
@@ -2,14 +2,38 @@ package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||
// We provide audit logging for data access tracking
|
||||
// We provide auth enforcement and audit logging for data access tracking
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeQueryList - Auth check before list query execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 0: BeforeQuery - Auth check before single query execution
|
||||
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
// ModelRules defines the permissions and security settings for a model
|
||||
type ModelRules struct {
|
||||
CanPublicRead bool // Whether the model can be read (GET operations)
|
||||
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanPublicCreate bool // Whether the model can be created (POST operations)
|
||||
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
|
||||
CanRead bool // Whether the model can be read (GET operations)
|
||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanCreate bool // Whether the model can be created (POST operations)
|
||||
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
|
||||
CanUpdate: true,
|
||||
CanCreate: true,
|
||||
CanDelete: true,
|
||||
CanPublicRead: false,
|
||||
CanPublicUpdate: false,
|
||||
CanPublicCreate: false,
|
||||
CanPublicDelete: false,
|
||||
SecurityDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
|
||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||
- **Database Agnostic**: GORM and Bun ORM support
|
||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
||||
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
|
||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||
- **Thread-safe**: Proper concurrency handling throughout
|
||||
|
||||
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
|
||||
|
||||
### Hook Types
|
||||
|
||||
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||
- `BeforeRead` / `AfterRead` - Read operations
|
||||
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||
|
||||
### Security Hooks (Recommended)
|
||||
|
||||
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
mqttspec.RegisterSecurityHooks(handler, securityList)
|
||||
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||
```
|
||||
|
||||
### Authentication Example (JWT)
|
||||
|
||||
```go
|
||||
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
||||
| **Hooks** | Same 13 hooks | Same 13 hooks |
|
||||
| **CRUD Operations** | Identical | Identical |
|
||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||
|
||||
|
||||
@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
|
||||
},
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
hookCtx.Operation = string(msg.Operation)
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
if hookCtx.Abort {
|
||||
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
|
||||
@@ -20,8 +20,11 @@ type (
|
||||
HookRegistry = websocketspec.HookRegistry
|
||||
)
|
||||
|
||||
// Hook type constants - all 12 lifecycle hooks
|
||||
// Hook type constants - all lifecycle hooks
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch
|
||||
BeforeHandle = websocketspec.BeforeHandle
|
||||
|
||||
// CRUD operation hooks
|
||||
BeforeRead = websocketspec.BeforeRead
|
||||
AfterRead = websocketspec.AfterRead
|
||||
|
||||
108
pkg/mqttspec/security_hooks.go
Normal file
108
pkg/mqttspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for mqttspec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
// GetQuery retrieves a stored query from hook metadata
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
if s.ctx.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return s.ctx.Metadata["query"]
|
||||
}
|
||||
|
||||
// SetQuery stores the query in hook metadata
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if s.ctx.Metadata == nil {
|
||||
s.ctx.Metadata = make(map[string]interface{})
|
||||
}
|
||||
s.ctx.Metadata["query"] = query
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
670
pkg/resolvemcp/README.md
Normal file
670
pkg/resolvemcp/README.md
Normal file
@@ -0,0 +1,670 @@
|
||||
# resolvemcp
|
||||
|
||||
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// 1. Create a handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
|
||||
BaseURL: "http://localhost:8080",
|
||||
})
|
||||
|
||||
// 2. Register models
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "orders", &Order{})
|
||||
|
||||
// 3. Mount routes
|
||||
r := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Config
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
// If empty, it is detected from each incoming request using the Host header and
|
||||
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
|
||||
// Required.
|
||||
BasePath string
|
||||
}
|
||||
```
|
||||
|
||||
## Handler Creation
|
||||
|
||||
| Function | Description |
|
||||
|---|---|
|
||||
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
|
||||
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
|
||||
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
|
||||
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
|
||||
|
||||
---
|
||||
|
||||
## Registering Models
|
||||
|
||||
```go
|
||||
handler.RegisterModel(schema, entity string, model interface{}) error
|
||||
```
|
||||
|
||||
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
|
||||
- `entity` — table/entity name (e.g. `"users"`).
|
||||
- `model` — a pointer to a struct (e.g. `&User{}`).
|
||||
|
||||
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
|
||||
|
||||
---
|
||||
|
||||
## HTTP Transports
|
||||
|
||||
`Config.BasePath` is required and used for all route registration.
|
||||
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||
|
||||
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
|
||||
|
||||
---
|
||||
|
||||
### SSE Transport
|
||||
|
||||
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
| Route | Method | Description |
|
||||
|---|---|---|
|
||||
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||
```
|
||||
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
sse := handler.SSEServer()
|
||||
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
|
||||
http.Handle("/mcp/", sse) // net/http
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Streamable HTTP Transport
|
||||
|
||||
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
|
||||
```
|
||||
|
||||
Mounts the handler at `{BasePath}` (all methods).
|
||||
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers GET, POST, DELETE on `{BasePath}`.
|
||||
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
h := handler.StreamableHTTPServer()
|
||||
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
|
||||
engine.Any("/mcp", gin.WrapH(h)) // Gin
|
||||
http.Handle("/mcp", h) // net/http
|
||||
e.Any("/mcp", echo.WrapHandler(h)) // Echo
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## OAuth2 Authentication
|
||||
|
||||
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
|
||||
|
||||
It can operate as:
|
||||
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
|
||||
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
|
||||
- **Both simultaneously**
|
||||
|
||||
### Standard endpoints served
|
||||
|
||||
| Path | Spec | Purpose |
|
||||
|---|---|---|
|
||||
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
|
||||
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
|
||||
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
|
||||
| `POST /oauth/authorize` | — | Login form submission |
|
||||
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
|
||||
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
|
||||
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
|
||||
|
||||
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
|
||||
|
||||
---
|
||||
|
||||
### Mode 1 — Direct login (server as identity provider)
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
|
||||
BaseURL: "https://api.example.com",
|
||||
BasePath: "/mcp",
|
||||
})
|
||||
|
||||
// Enable the OAuth2 server — auth enables the login form
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
security.RegisterSecurityHooks(handler, securityList)
|
||||
|
||||
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
```
|
||||
|
||||
MCP client flow:
|
||||
1. Discovers server at `/.well-known/oauth-authorization-server`
|
||||
2. Registers itself at `/oauth/register`
|
||||
3. Redirects user to `/oauth/authorize` → login form appears
|
||||
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
|
||||
5. Uses token on all MCP tool calls
|
||||
|
||||
---
|
||||
|
||||
### Mode 2 — External provider (Google, GitHub, etc.)
|
||||
|
||||
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
RedirectURL: "https://api.example.com/oauth/provider/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
|
||||
// nil = no password login; Google handles auth
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, nil)
|
||||
handler.RegisterOAuth2Provider(auth, "google")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Mode 3 — Both (login form + external providers)
|
||||
|
||||
```go
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
LoginTitle: "My App Login",
|
||||
}, auth) // auth enables the username/password form
|
||||
|
||||
handler.RegisterOAuth2Provider(googleAuth, "google")
|
||||
handler.RegisterOAuth2Provider(githubAuth, "github")
|
||||
```
|
||||
|
||||
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
|
||||
|
||||
---
|
||||
|
||||
### Using `security.OAuthServer` standalone
|
||||
|
||||
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
|
||||
|
||||
```go
|
||||
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
oauthSrv.RegisterExternalProvider(googleAuth, "google")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
|
||||
mux.Handle("/mcp/", myMCPHandler)
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Cookie-based flow (legacy)
|
||||
|
||||
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
ProviderName: "google",
|
||||
LoginPath: "/auth/google/login",
|
||||
CallbackPath: "/auth/google/callback",
|
||||
AfterLoginRedirect: "/",
|
||||
})
|
||||
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||
|
||||
### Wiring security hooks
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
securityList := security.NewSecurityList(mySecurityProvider)
|
||||
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||
```
|
||||
|
||||
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||
|
||||
| Hook | Effect |
|
||||
|---|---|
|
||||
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||
|
||||
### Per-entity operation rules
|
||||
|
||||
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
// Read-only entity
|
||||
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: false,
|
||||
CanUpdate: false,
|
||||
CanDelete: false,
|
||||
})
|
||||
|
||||
// Public read, authenticated write
|
||||
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||
CanPublicRead: true,
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
To update rules for an already-registered model:
|
||||
|
||||
```go
|
||||
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||
|
||||
### ModelRules fields
|
||||
|
||||
| Field | Default | Description |
|
||||
|---|---|---|
|
||||
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||
| `CanRead` | `true` | Allow authenticated reads |
|
||||
| `CanCreate` | `true` | Allow authenticated creates |
|
||||
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||
|
||||
---
|
||||
|
||||
## MCP Tools
|
||||
|
||||
### Tool Naming
|
||||
|
||||
```
|
||||
{operation}_{schema}_{entity} // e.g. read_public_users
|
||||
{operation}_{entity} // e.g. read_users (when schema is empty)
|
||||
```
|
||||
|
||||
Operations: `read`, `create`, `update`, `delete`.
|
||||
|
||||
### Read Tool — `read_{schema}_{entity}`
|
||||
|
||||
Fetch one or many records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key value. Omit to return multiple records. |
|
||||
| `limit` | number | Max records per page (recommended: 10–100). |
|
||||
| `offset` | number | Records to skip (offset-based pagination). |
|
||||
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
|
||||
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
|
||||
| `columns` | array | Column names to include. Omit for all columns. |
|
||||
| `omit_columns` | array | Column names to exclude. |
|
||||
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
|
||||
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
|
||||
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 100,
|
||||
"count": 10,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create Tool — `create_{schema}_{entity}`
|
||||
|
||||
Insert one or more records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `data` | object \| array | Single object or array of objects to insert. |
|
||||
|
||||
Array input runs inside a single transaction — all succeed or all fail.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ... } }
|
||||
```
|
||||
|
||||
### Update Tool — `update_{schema}_{entity}`
|
||||
|
||||
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key of the record. Can also be included inside `data`. |
|
||||
| `data` | object (required) | Fields to update. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...merged record... } }
|
||||
```
|
||||
|
||||
### Delete Tool — `delete_{schema}_{entity}`
|
||||
|
||||
Delete a record by primary key. **Irreversible.**
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string (required) | Primary key of the record to delete. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...deleted record... } }
|
||||
```
|
||||
|
||||
### Annotation Tool — `resolvespec_annotate`
|
||||
|
||||
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||
|
||||
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||
```
|
||||
|
||||
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users" }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resource — `{schema}.{entity}`
|
||||
|
||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||
|
||||
---
|
||||
|
||||
## Filtering
|
||||
|
||||
Pass an array of filter objects to the `filters` argument:
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "status", "operator": "=", "value": "active" },
|
||||
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
|
||||
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
|
||||
]
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
| Operator | Aliases | Description |
|
||||
|---|---|---|
|
||||
| `=` | `eq` | Equal |
|
||||
| `!=` | `neq`, `<>` | Not equal |
|
||||
| `>` | `gt` | Greater than |
|
||||
| `>=` | `gte` | Greater than or equal |
|
||||
| `<` | `lt` | Less than |
|
||||
| `<=` | `lte` | Less than or equal |
|
||||
| `like` | | SQL LIKE (case-sensitive) |
|
||||
| `ilike` | | SQL ILIKE (case-insensitive) |
|
||||
| `in` | | Value in list |
|
||||
| `is_null` | | Column IS NULL |
|
||||
| `is_not_null` | | Column IS NOT NULL |
|
||||
|
||||
### Logic Operators
|
||||
|
||||
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
|
||||
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
|
||||
|
||||
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
|
||||
|
||||
---
|
||||
|
||||
## Sorting
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "created_at", "direction": "desc" },
|
||||
{ "column": "name", "direction": "asc" }
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pagination
|
||||
|
||||
### Offset-Based
|
||||
|
||||
```json
|
||||
{ "limit": 20, "offset": 40 }
|
||||
```
|
||||
|
||||
### Cursor-Based
|
||||
|
||||
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
|
||||
|
||||
```json
|
||||
// Next page: pass the PK of the last record on the current page
|
||||
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
|
||||
// Previous page: pass the PK of the first record on the current page
|
||||
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Preloading Relations
|
||||
|
||||
```json
|
||||
[
|
||||
{ "relation": "Profile" },
|
||||
{ "relation": "Orders" }
|
||||
]
|
||||
```
|
||||
|
||||
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
|
||||
|
||||
---
|
||||
|
||||
## Hook System
|
||||
|
||||
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
|
||||
|
||||
### Hook Types
|
||||
|
||||
| Constant | Fires |
|
||||
|---|---|
|
||||
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
|
||||
| `BeforeRead` / `AfterRead` | Around read queries |
|
||||
| `BeforeCreate` / `AfterCreate` | Around insert |
|
||||
| `BeforeUpdate` / `AfterUpdate` | Around update |
|
||||
| `BeforeDelete` / `AfterDelete` | Around delete |
|
||||
|
||||
### Registering Hooks
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
|
||||
// Inject a timestamp before insert
|
||||
if data, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
data["created_at"] = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register the same hook for multiple events
|
||||
handler.Hooks().RegisterMultiple(
|
||||
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
|
||||
auditHook,
|
||||
)
|
||||
```
|
||||
|
||||
### HookContext Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `Context` | `context.Context` | Request context |
|
||||
| `Handler` | `*Handler` | The resolvemcp handler |
|
||||
| `Schema` | `string` | Database schema name |
|
||||
| `Entity` | `string` | Entity/table name |
|
||||
| `Model` | `interface{}` | Registered model instance |
|
||||
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
|
||||
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
|
||||
| `ID` | `string` | Primary key from request (read/update/delete) |
|
||||
| `Data` | `interface{}` | Input data (create/update — modifiable) |
|
||||
| `Result` | `interface{}` | Output data (set by After hooks) |
|
||||
| `Error` | `error` | Operation error, if any |
|
||||
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
|
||||
| `Tx` | `common.Database` | Database/transaction handle |
|
||||
| `Abort` | `bool` | Set to `true` to abort the operation |
|
||||
| `AbortMessage` | `string` | Error message returned when aborting |
|
||||
| `AbortCode` | `int` | Optional status code for the abort |
|
||||
|
||||
### Aborting an Operation
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "deletion is disabled"
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Managing Hooks
|
||||
|
||||
```go
|
||||
registry := handler.Hooks()
|
||||
registry.HasHooks(resolvemcp.BeforeCreate) // bool
|
||||
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
|
||||
registry.ClearAll() // remove all hooks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Helpers
|
||||
|
||||
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
|
||||
|
||||
```go
|
||||
schema := resolvemcp.GetSchema(ctx)
|
||||
entity := resolvemcp.GetEntity(ctx)
|
||||
tableName := resolvemcp.GetTableName(ctx)
|
||||
model := resolvemcp.GetModel(ctx)
|
||||
modelPtr := resolvemcp.GetModelPtr(ctx)
|
||||
```
|
||||
|
||||
You can also set values manually (e.g. in middleware):
|
||||
|
||||
```go
|
||||
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding Custom MCP Tools
|
||||
|
||||
Access the underlying `*server.MCPServer` to register additional tools:
|
||||
|
||||
```go
|
||||
mcpServer := handler.MCPServer()
|
||||
mcpServer.AddTool(myTool, myHandler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Table Name Resolution
|
||||
|
||||
The handler resolves table names in priority order:
|
||||
|
||||
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
|
||||
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
|
||||
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)
|
||||
107
pkg/resolvemcp/annotation.go
Normal file
107
pkg/resolvemcp/annotation.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const annotationToolName = "resolvespec_annotate"
|
||||
|
||||
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||
// The tool lets models/entities store and retrieve freeform annotation records
|
||||
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||
func registerAnnotationTool(h *Handler) {
|
||||
tool := mcp.NewTool(annotationToolName,
|
||||
mcp.WithDescription(
|
||||
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||
"To get annotations: provide only 'tool_name'. "+
|
||||
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||
),
|
||||
mcp.WithString("tool_name",
|
||||
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithObject("annotations",
|
||||
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
|
||||
toolName, ok := args["tool_name"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||
}
|
||||
|
||||
annotations, hasAnnotations := args["annotations"]
|
||||
|
||||
if hasAnnotations && annotations != nil {
|
||||
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||
}
|
||||
return executeGetAnnotation(ctx, h, toolName)
|
||||
})
|
||||
}
|
||||
|
||||
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||
jsonBytes, err := json.Marshal(annotations)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||
}
|
||||
|
||||
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "set",
|
||||
})
|
||||
}
|
||||
|
||||
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||
var rows []map[string]interface{}
|
||||
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
var annotations interface{}
|
||||
if len(rows) > 0 {
|
||||
// The procedure returns a single value; extract the first column of the first row.
|
||||
for _, v := range rows[0] {
|
||||
annotations = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||
switch v := annotations.(type) {
|
||||
case []byte:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal(v, &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
case string:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "get",
|
||||
"annotations": annotations,
|
||||
})
|
||||
}
|
||||
71
pkg/resolvemcp/context.go
Normal file
71
pkg/resolvemcp/context.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package resolvemcp
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeySchema contextKey = "schema"
|
||||
contextKeyEntity contextKey = "entity"
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
)
|
||||
|
||||
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||
return context.WithValue(ctx, contextKeySchema, schema)
|
||||
}
|
||||
|
||||
func GetSchema(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeySchema); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||
}
|
||||
|
||||
func GetEntity(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||
}
|
||||
|
||||
func GetTableName(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModel, model)
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModel)
|
||||
}
|
||||
|
||||
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||
}
|
||||
|
||||
func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
return ctx
|
||||
}
|
||||
161
pkg/resolvemcp/cursor.go
Normal file
161
pkg/resolvemcp/cursor.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package resolvemcp
|
||||
|
||||
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type cursorDirection int
|
||||
|
||||
const (
|
||||
cursorForward cursorDirection = 1
|
||||
cursorBackward cursorDirection = -1
|
||||
)
|
||||
|
||||
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||
func getCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
orSQL := buildCursorPriorityChain(whereClauses)
|
||||
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, cursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, cursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
func buildCursorPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
760
pkg/resolvemcp/handler.go
Normal file
760
pkg/resolvemcp/handler.go
Normal file
@@ -0,0 +1,760 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
oauth2Regs []oauth2Registration
|
||||
oauthSrv *security.OAuthServer
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||
config: cfg,
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
registerAnnotationTool(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry.
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// GetDatabase returns the underlying database.
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||
func (h *Handler) MCPServer() *server.MCPServer {
|
||||
return h.mcpServer
|
||||
}
|
||||
|
||||
// SSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
|
||||
// detected automatically from each incoming request.
|
||||
func (h *Handler) SSEServer() http.Handler {
|
||||
if h.config.BaseURL != "" {
|
||||
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
|
||||
}
|
||||
return &dynamicSSEHandler{h: h}
|
||||
}
|
||||
|
||||
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
|
||||
// client-server communication (POST for requests, GET for server-initiated messages).
|
||||
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
|
||||
func (h *Handler) StreamableHTTPServer() http.Handler {
|
||||
return server.NewStreamableHTTPServer(h.mcpServer)
|
||||
}
|
||||
|
||||
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithStaticBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
|
||||
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
|
||||
type dynamicSSEHandler struct {
|
||||
h *Handler
|
||||
mu sync.Mutex
|
||||
pool map[string]*server.SSEServer
|
||||
}
|
||||
|
||||
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := requestBaseURL(r)
|
||||
|
||||
d.mu.Lock()
|
||||
if d.pool == nil {
|
||||
d.pool = make(map[string]*server.SSEServer)
|
||||
}
|
||||
s, ok := d.pool[baseURL]
|
||||
if !ok {
|
||||
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
|
||||
d.pool[baseURL] = s
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requestBaseURL builds the base URL from an incoming request.
|
||||
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
|
||||
func requestBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
}
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
|
||||
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetModelRules updates the operation rules for an already-registered model.
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||
}
|
||||
|
||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||
func buildModelName(schema, entity string) string {
|
||||
if schema == "" {
|
||||
return entity
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema, entity)
|
||||
}
|
||||
|
||||
// getTableName returns the fully qualified table name for a model.
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := tableProvider.TableName()
|
||||
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||
return tableName[:idx], tableName[idx+1:]
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), tableName
|
||||
}
|
||||
return defaultSchema, tableName
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), entity
|
||||
}
|
||||
return defaultSchema, entity
|
||||
}
|
||||
|
||||
// recoverPanic catches a panic from the current goroutine and returns it as an error.
|
||||
// Usage: defer recoverPanic(&returnedErr)
|
||||
func recoverPanic(err *error) {
|
||||
if r := recover(); r != nil {
|
||||
msg := fmt.Sprintf("%v", r)
|
||||
logger.Error("[resolvemcp] panic recovered: %s", msg)
|
||||
*err = fmt.Errorf("internal error: %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// executeRead reads records from the database and returns raw data + metadata.
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = unwrapped.Model
|
||||
modelType := unwrapped.ModelType
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = validator.FilterRequestOptions(options)
|
||||
|
||||
// BeforeHandle hook
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "read",
|
||||
Options: options,
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||
modelPtr := reflect.New(sliceType).Interface()
|
||||
|
||||
query := h.db.NewSelect().Model(modelPtr)
|
||||
|
||||
tempInstance := reflect.New(modelType).Interface()
|
||||
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// Column selection
|
||||
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
for _, cu := range options.ComputedColumns {
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
|
||||
// Filters
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Sorting
|
||||
for _, sort := range options.Sort {
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||
}
|
||||
|
||||
if cursorFilter != "" {
|
||||
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||
if sanitized != "" {
|
||||
query = query.Where(sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count — must happen before preloads are applied; Bun panics when counting with relations.
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// Preloads — applied after count to avoid Bun panic when counting with relations.
|
||||
if len(options.Preload) > 0 {
|
||||
var preloadErr error
|
||||
query, preloadErr = h.applyPreloads(model, query, options.Preload)
|
||||
if preloadErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeRead hook
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var data interface{}
|
||||
if id != "" {
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, fmt.Errorf("record not found")
|
||||
}
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = singleResult
|
||||
} else {
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||
}
|
||||
|
||||
limit := 0
|
||||
offset := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Count is the number of records in this page, not the total.
|
||||
var pageCount int64
|
||||
if id != "" {
|
||||
pageCount = 1
|
||||
} else {
|
||||
pageCount = int64(reflect.ValueOf(data).Len())
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: pageCount,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// AfterRead hook
|
||||
hookCtx.Result = data
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return data, metadata, nil
|
||||
}
|
||||
|
||||
// executeCreate inserts one or more records.
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "create",
|
||||
Data: data,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("each item must be an object")
|
||||
}
|
||||
q := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||
}
|
||||
}
|
||||
|
||||
// executeUpdate updates a record by ID.
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
updates, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("data must be an object")
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
if idVal, exists := updates["id"]; exists {
|
||||
id = fmt.Sprintf("%v", idVal)
|
||||
}
|
||||
}
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
var updateResult interface{}
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Read existing record
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
existingRecord := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "update",
|
||||
ID: id,
|
||||
Data: updates,
|
||||
Tx: tx,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge non-nil, non-empty values
|
||||
for key, newValue := range updates {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
res, err := q.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
updateResult = existingMap
|
||||
hookCtx.Result = updateResult
|
||||
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
// executeDelete deletes a record by ID.
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "delete",
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
var recordToDelete interface{}
|
||||
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
record := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(record).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found")
|
||||
}
|
||||
return fmt.Errorf("error fetching record: %w", err)
|
||||
}
|
||||
|
||||
res, err := tx.NewDelete().Table(tableName).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete error: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("record not found or already deleted")
|
||||
}
|
||||
|
||||
recordToDelete = record
|
||||
hookCtx.Tx = tx
|
||||
hookCtx.Result = record
|
||||
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||
return recordToDelete, nil
|
||||
}
|
||||
|
||||
// applyFilters applies all filters with OR grouping logic.
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return condition, args
|
||||
case "is_null":
|
||||
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||
case "is_not_null":
|
||||
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for _, preload := range preloads {
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
query = query.PreloadRelation(preload.Relation)
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
113
pkg/resolvemcp/hooks.go
Normal file
113
pkg/resolvemcp/hooks.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler
|
||||
Schema string
|
||||
Entity string
|
||||
Model interface{}
|
||||
Options common.RequestOptions
|
||||
Operation string
|
||||
ID string
|
||||
Data interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
Query common.SelectQuery
|
||||
Abort bool
|
||||
AbortMessage string
|
||||
AbortCode int
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
if ctx.Abort {
|
||||
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
return exists && len(hooks) > 0
|
||||
}
|
||||
264
pkg/resolvemcp/oauth2.go
Normal file
264
pkg/resolvemcp/oauth2.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 registration on the Handler
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// oauth2Registration stores a configured auth provider and its route config.
|
||||
type oauth2Registration struct {
|
||||
auth *security.DatabaseAuthenticator
|
||||
cfg OAuth2RouteConfig
|
||||
}
|
||||
|
||||
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
|
||||
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
|
||||
// Call this once per provider before serving requests.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google",
|
||||
// LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback",
|
||||
// AfterLoginRedirect: "/",
|
||||
// })
|
||||
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
|
||||
}
|
||||
|
||||
// HTTPHandler returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP SSE transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(...)
|
||||
// handler.RegisterOAuth2(auth, cfg)
|
||||
// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"})
|
||||
// security.RegisterSecurityHooks(handler, securityList)
|
||||
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedSSEServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/sse", mcpHandler)
|
||||
mux.Handle(basePath+"/message", mcpHandler)
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// StreamableHTTPMux returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP streamable HTTP transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
|
||||
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
mux.Handle(basePath, mcpHandler)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
|
||||
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
|
||||
for _, reg := range h.oauth2Regs {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if reg.cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
|
||||
}
|
||||
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
|
||||
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Auth-wrapped transports
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
|
||||
// The middleware reads the session cookie / Authorization header and populates the user
|
||||
// context into the request context, making it available to BeforeHandle security hooks.
|
||||
// Unauthenticated requests receive 401 before reaching any MCP tool.
|
||||
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
|
||||
// Unauthenticated requests continue as guest rather than returning 401.
|
||||
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
|
||||
// to allow mixed public/private access.
|
||||
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
|
||||
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
|
||||
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 route config and standalone handlers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
|
||||
type OAuth2RouteConfig struct {
|
||||
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
|
||||
// (e.g. "google", "github", "microsoft").
|
||||
ProviderName string
|
||||
|
||||
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
|
||||
// (e.g. "/auth/google/login").
|
||||
LoginPath string
|
||||
|
||||
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
|
||||
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
|
||||
CallbackPath string
|
||||
|
||||
// AfterLoginRedirect is the URL to redirect the browser to after a successful
|
||||
// login. When empty the LoginResponse JSON is written directly to the response.
|
||||
AfterLoginRedirect string
|
||||
|
||||
// CookieOptions customises the session cookie written on successful login.
|
||||
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
|
||||
CookieOptions *security.SessionCookieOptions
|
||||
}
|
||||
|
||||
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
|
||||
// the OAuth2 provider's authorization URL.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
|
||||
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := auth.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to generate state", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
|
||||
// callback: exchanges the authorization code for a session token, writes the session
|
||||
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
|
||||
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
security.SetSessionCookie(w, loginResp, cookieOpts...)
|
||||
|
||||
if afterLoginRedirect != "" {
|
||||
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Gorilla Mux convenience helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google", LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
|
||||
// })
|
||||
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
|
||||
}
|
||||
|
||||
muxRouter.Handle(cfg.LoginPath,
|
||||
OAuth2LoginHandler(auth, cfg.ProviderName),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
muxRouter.Handle(cfg.CallbackPath,
|
||||
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
|
||||
).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
|
||||
// with required authentication middleware applied.
|
||||
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedSSEServer(securityList)
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
|
||||
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
|
||||
// Gorilla Mux router with required authentication middleware applied.
|
||||
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedStreamableHTTPServer(securityList)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
51
pkg/resolvemcp/oauth2_server.go
Normal file
51
pkg/resolvemcp/oauth2_server.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
|
||||
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
|
||||
// only external providers registered via RegisterOAuth2Provider.
|
||||
//
|
||||
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
|
||||
// endpoints required by MCP clients alongside the MCP transport:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
|
||||
// POST /oauth/authorize Login form submission (password flow)
|
||||
// POST /oauth/token Bearer token exchange + refresh
|
||||
// GET /oauth/provider/callback External provider redirect target
|
||||
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
|
||||
h.oauthSrv = security.NewOAuthServer(cfg, auth)
|
||||
// Wire any external providers already registered via RegisterOAuth2
|
||||
for _, reg := range h.oauth2Regs {
|
||||
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
|
||||
// EnableOAuthServer must be called before this. The auth must have been configured with
|
||||
// WithOAuth2(providerName, ...) for the given provider name.
|
||||
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
|
||||
if h.oauthSrv != nil {
|
||||
h.oauthSrv.RegisterExternalProvider(auth, providerName)
|
||||
}
|
||||
}
|
||||
|
||||
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
|
||||
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
|
||||
oauthHandler := h.oauthSrv.HTTPHandler()
|
||||
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
|
||||
mux.Handle("/.well-known/", oauthHandler)
|
||||
mux.Handle("/oauth/", oauthHandler)
|
||||
if h.oauthSrv != nil {
|
||||
// Also mount the external provider callback path if it differs from /oauth/
|
||||
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
|
||||
}
|
||||
}
|
||||
133
pkg/resolvemcp/resolvemcp.go
Normal file
133
pkg/resolvemcp/resolvemcp.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||
// and resources over HTTP/SSE transport.
|
||||
//
|
||||
// It mirrors the resolvespec package patterns:
|
||||
// - Same model registration API
|
||||
// - Same filter, sort, cursor pagination, preload options
|
||||
// - Same lifecycle hook system
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// resolvemcp.SetupMuxRoutes(r, handler)
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/uptrace/bun"
|
||||
bunrouter "github.com/uptrace/bunrouter"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// Config holds configuration for the resolvemcp handler.
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
|
||||
// If empty, the path is detected from each incoming request automatically.
|
||||
BasePath string
|
||||
}
|
||||
|
||||
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
|
||||
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
|
||||
//
|
||||
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||
// before calling SetupMuxRoutes.
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
|
||||
|
||||
// Convenience: also expose the full SSE server at basePath for clients that
|
||||
// use ServeHTTP directly (e.g. net/http default mux).
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
|
||||
// using the base path from Config.BasePath.
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint
|
||||
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewSSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// If Config.BasePath is set it is used directly; otherwise the base path is
|
||||
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
|
||||
//
|
||||
// h := resolvemcp.NewSSEServer(handler)
|
||||
// http.Handle("/api/mcp/", h)
|
||||
func NewSSEServer(handler *Handler) http.Handler {
|
||||
return handler.SSEServer()
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
|
||||
// POST for client→server messages, GET for server→client streaming.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
|
||||
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
|
||||
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
router.GET(basePath, bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath, bunrouter.HTTPHandler(h))
|
||||
router.DELETE(basePath, bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Mount it at the desired path; that path becomes the MCP endpoint.
|
||||
//
|
||||
// h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
// http.Handle("/mcp", h)
|
||||
// engine.Any("/mcp", gin.WrapH(h))
|
||||
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
|
||||
return handler.StreamableHTTPServer()
|
||||
}
|
||||
115
pkg/resolvemcp/security_hooks.go
Normal file
115
pkg/resolvemcp/security_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||
//
|
||||
// The following controls are applied:
|
||||
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||
// stored via RegisterModelWithRules / SetModelRules.
|
||||
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||
// - Audit logging after each read.
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (2nd): audit log.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeUpdate: enforce CanUpdate rule.
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeDelete: enforce CanDelete rule.
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvemcp handler")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
return s.ctx.Query
|
||||
}
|
||||
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if q, ok := query.(common.SelectQuery); ok {
|
||||
s.ctx.Query = q
|
||||
}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
692
pkg/resolvemcp/tools.go
Normal file
692
pkg/resolvemcp/tools.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// toolName builds the MCP tool name for a given operation and model.
|
||||
func toolName(operation, schema, entity string) string {
|
||||
if schema == "" {
|
||||
return fmt.Sprintf("%s_%s", operation, entity)
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||
}
|
||||
|
||||
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||
info := buildModelInfo(schema, entity, model)
|
||||
registerReadTool(h, schema, entity, info)
|
||||
registerCreateTool(h, schema, entity, info)
|
||||
registerUpdateTool(h, schema, entity, info)
|
||||
registerDeleteTool(h, schema, entity, info)
|
||||
registerModelResource(h, schema, entity, info)
|
||||
|
||||
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Model introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||
type modelInfo struct {
|
||||
fullName string // e.g. "public.users"
|
||||
pkName string // e.g. "id"
|
||||
columns []columnInfo
|
||||
relationNames []string
|
||||
schemaDoc string // formatted multi-line schema listing
|
||||
}
|
||||
|
||||
type columnInfo struct {
|
||||
jsonName string
|
||||
sqlName string
|
||||
goType string
|
||||
sqlType string
|
||||
isPrimary bool
|
||||
isUnique bool
|
||||
isFK bool
|
||||
nullable bool
|
||||
}
|
||||
|
||||
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
info := modelInfo{
|
||||
fullName: buildModelName(schema, entity),
|
||||
pkName: reflection.GetPrimaryKeyName(model),
|
||||
}
|
||||
|
||||
// Unwrap to base struct type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return info
|
||||
}
|
||||
|
||||
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||
|
||||
for _, d := range details {
|
||||
// Derive the JSON name from the struct field
|
||||
jsonName := fieldJSONName(modelType, d.Name)
|
||||
if jsonName == "" || jsonName == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||
fieldType, found := modelType.FieldByName(d.Name)
|
||||
if found {
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||
info.relationNames = append(info.relationNames, jsonName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
sqlName := d.SQLName
|
||||
if sqlName == "" {
|
||||
sqlName = jsonName
|
||||
}
|
||||
|
||||
// Derive Go type name, unwrapping pointer if needed.
|
||||
goType := d.DataType
|
||||
if goType == "" && found {
|
||||
ft := fieldType.Type
|
||||
for ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
goType = ft.Name()
|
||||
}
|
||||
|
||||
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||
isPrimary := d.SQLKey == "primary_key" ||
|
||||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||
|
||||
ci := columnInfo{
|
||||
jsonName: jsonName,
|
||||
sqlName: sqlName,
|
||||
goType: goType,
|
||||
sqlType: d.SQLDataType,
|
||||
isPrimary: isPrimary,
|
||||
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||
isFK: d.SQLKey == "foreign_key",
|
||||
nullable: d.Nullable,
|
||||
}
|
||||
info.columns = append(info.columns, ci)
|
||||
}
|
||||
|
||||
info.schemaDoc = buildSchemaDoc(info)
|
||||
return info
|
||||
}
|
||||
|
||||
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||
field, ok := modelType.FieldByName(fieldName)
|
||||
if !ok {
|
||||
return fieldName
|
||||
}
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "" {
|
||||
return fieldName
|
||||
}
|
||||
parts := strings.SplitN(tag, ",", 2)
|
||||
if parts[0] == "" {
|
||||
return fieldName
|
||||
}
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||
func buildSchemaDoc(info modelInfo) string {
|
||||
if len(info.columns) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Columns:\n")
|
||||
for _, c := range info.columns {
|
||||
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||
|
||||
typeDesc := c.goType
|
||||
if c.sqlType != "" {
|
||||
typeDesc = c.sqlType
|
||||
}
|
||||
if typeDesc != "" {
|
||||
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||
}
|
||||
|
||||
var flags []string
|
||||
if c.isPrimary {
|
||||
flags = append(flags, "primary key")
|
||||
}
|
||||
if c.isUnique {
|
||||
flags = append(flags, "unique")
|
||||
}
|
||||
if c.isFK {
|
||||
flags = append(flags, "foreign key")
|
||||
}
|
||||
if !c.nullable && !c.isPrimary {
|
||||
flags = append(flags, "not null")
|
||||
} else if c.nullable {
|
||||
flags = append(flags, "nullable")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
line += " — " + strings.Join(flags, ", ")
|
||||
}
|
||||
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
|
||||
if len(info.relationNames) > 0 {
|
||||
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||
func columnNameList(cols []columnInfo) string {
|
||||
names := make([]string, len(cols))
|
||||
for i, c := range cols {
|
||||
names[i] = c.jsonName
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||
func writableColumnNames(cols []columnInfo) []string {
|
||||
var names []string
|
||||
for _, c := range cols {
|
||||
if !c.isPrimary {
|
||||
names = append(names, c.jsonName)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Read tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("read", schema, entity)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||
)
|
||||
if len(info.relationNames) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
|
||||
if len(info.columns) > 0 {
|
||||
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||
if len(info.columns) > 0 {
|
||||
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||
),
|
||||
mcp.WithNumber("limit",
|
||||
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||
),
|
||||
mcp.WithNumber("offset",
|
||||
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||
),
|
||||
mcp.WithString("cursor_forward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithString("cursor_backward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithArray("columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("omit_columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("filters",
|
||||
mcp.Description(filterDesc),
|
||||
),
|
||||
mcp.WithArray("sort",
|
||||
mcp.Description(sortDesc),
|
||||
),
|
||||
mcp.WithArray("preloads",
|
||||
mcp.Description(buildPreloadDesc(info)),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
options := parseRequestOptions(args)
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func buildPreloadDesc(info modelInfo) string {
|
||||
if len(info.relationNames) == 0 {
|
||||
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||
strings.Join(info.relationNames, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("create", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
dataDesc := "Record fields to create."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
dataDesc += " Pass a single object or an array of objects."
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Update tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("update", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||
}
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||
|
||||
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(idDesc),
|
||||
),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("data must be an object"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Delete tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("delete", schema, entity)
|
||||
|
||||
descParts := []string{
|
||||
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||
|
||||
description := strings.Join(descParts, " ")
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Resource registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||
resourceURI := info.fullName
|
||||
|
||||
var resourceDesc strings.Builder
|
||||
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
|
||||
if info.pkName != "" {
|
||||
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
resourceDesc.WriteString("\n\n")
|
||||
resourceDesc.WriteString(info.schemaDoc)
|
||||
}
|
||||
|
||||
resource := mcp.NewResource(
|
||||
resourceURI,
|
||||
entity,
|
||||
mcp.WithResourceDescription(resourceDesc.String()),
|
||||
mcp.WithMIMEType("application/json"),
|
||||
)
|
||||
|
||||
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
limit := 100
|
||||
options := common.RequestOptions{Limit: &limit}
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
}
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||
}
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "application/json",
|
||||
Text: string(jsonBytes),
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Argument parsing helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||
options := common.RequestOptions{}
|
||||
|
||||
if v, ok := args["limit"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
limit := int(n)
|
||||
options.Limit = &limit
|
||||
case int:
|
||||
options.Limit = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["offset"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
offset := int(n)
|
||||
options.Offset = &offset
|
||||
case int:
|
||||
options.Offset = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["cursor_forward"].(string); ok {
|
||||
options.CursorForward = v
|
||||
}
|
||||
if v, ok := args["cursor_backward"].(string); ok {
|
||||
options.CursorBackward = v
|
||||
}
|
||||
|
||||
options.Columns = parseStringArray(args["columns"])
|
||||
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||
options.Filters = parseFilters(args["filters"])
|
||||
options.Sort = parseSortOptions(args["sort"])
|
||||
options.Preload = parsePreloadOptions(args["preloads"])
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func parseStringArray(raw interface{}) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseFilters(raw interface{}) []common.FilterOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.FilterOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var f common.FilterOption
|
||||
if err := json.Unmarshal(b, &f); err != nil {
|
||||
continue
|
||||
}
|
||||
if f.Column == "" || f.Operator == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(f.LogicOperator, "or") {
|
||||
f.LogicOperator = "OR"
|
||||
} else {
|
||||
f.LogicOperator = "AND"
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.SortOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var s common.SortOption
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
continue
|
||||
}
|
||||
if s.Column == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.PreloadOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var p common.PreloadOption
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
continue
|
||||
}
|
||||
if p.Relation == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||
}
|
||||
return mcp.NewToolResultText(string(b)), nil
|
||||
}
|
||||
@@ -644,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||
}
|
||||
|
||||
// Schema.Table format
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
@@ -654,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
@@ -31,8 +32,10 @@ func GetCursorFilter(
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
// Remove schema prefix if present
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
@@ -57,18 +60,19 @@ func GetCursorFilter(
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
@@ -81,7 +85,7 @@ func GetCursorFilter(
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -89,6 +93,22 @@ func GetCursorFilter(
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
@@ -112,10 +132,12 @@ func GetCursorFilter(
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
@@ -136,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
// Helper: resolve column (main table or join)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -170,19 +170,50 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
// Should include full schema-qualified name in FROM clause
|
||||
if !strings.Contains(filter, "public.users") {
|
||||
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||
CursorForward: "8975",
|
||||
}
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
name: "Joined column (isJoin=true, no error)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
wantErr: false,
|
||||
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// For join columns, cursor/target are empty and isJoin=true
|
||||
if isJoin {
|
||||
if cursor != "" || target != "" {
|
||||
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -44,8 +44,8 @@ func TestBuildFilterCondition(t *testing.T) {
|
||||
Operator: "in",
|
||||
Value: []string{"active", "pending"},
|
||||
},
|
||||
expectedCondition: "status IN (?)",
|
||||
expectedArgsCount: 1,
|
||||
expectedCondition: "status IN (?,?)",
|
||||
expectedArgsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "LIKE operator",
|
||||
|
||||
@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
validator := common.NewColumnValidator(model)
|
||||
req.Options = validator.FilterRequestOptions(req.Options)
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: req.Operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Operation {
|
||||
case "read":
|
||||
h.handleRead(ctx, w, id, req.Options)
|
||||
@@ -309,8 +329,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
@@ -1236,6 +1261,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("BeforeDelete hook failed: %v", err)
|
||||
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle batch delete from request data
|
||||
if data != nil {
|
||||
switch v := data.(type) {
|
||||
@@ -1483,22 +1526,22 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq":
|
||||
case "eq", "=":
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq":
|
||||
case "neq", "!=", "<>":
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt":
|
||||
case "gt", ">":
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte":
|
||||
case "gte", ">=":
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt":
|
||||
case "lt", "<":
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte":
|
||||
case "lte", "<=":
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
@@ -1508,8 +1551,10 @@ func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionStr
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return "", nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
@@ -1525,22 +1570,22 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq":
|
||||
case "eq", "=":
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq":
|
||||
case "neq", "!=", "<>":
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt":
|
||||
case "gt", ">":
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte":
|
||||
case "gte", ">=":
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt":
|
||||
case "lt", "<":
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte":
|
||||
case "lte", "<=":
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
@@ -1550,8 +1595,10 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition = fmt.Sprintf("%s IN (?)", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return query
|
||||
}
|
||||
default:
|
||||
return query
|
||||
}
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -43,6 +47,9 @@ type HookContext struct {
|
||||
Writer common.ResponseWriter
|
||||
Request common.Request
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
|
||||
@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
||||
postEntityHandler = authMiddleware(postEntityHandler)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
|
||||
getEntityHandler = authMiddleware(getEntityHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -216,9 +216,34 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -256,7 +281,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// POST route without ID
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -267,10 +292,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// POST route with ID
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -282,10 +308,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// GET route without ID
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -296,10 +323,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
// GET route with ID
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -311,9 +339,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
@@ -330,6 +360,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
@@ -355,8 +386,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with bunrouter
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes with bunrouter without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -377,8 +408,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -396,8 +427,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup ResolveSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
|
||||
func ExampleWithGORMAndAuth(db *gorm.DB) {
|
||||
// Create handler using GORM
|
||||
_ = NewHandlerWithGORM(db)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup router with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Register models
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
|
||||
func ExampleWithBunAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup routes with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
|
||||
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun database adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler with Bun
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Create bunrouter
|
||||
_ = bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with authentication
|
||||
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -10,6 +11,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvespec handler")
|
||||
}
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
```
|
||||
|
||||
**Available Hook Types**:
|
||||
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
* `BeforeRead`, `AfterRead`
|
||||
* `BeforeCreate`, `AfterCreate`
|
||||
* `BeforeUpdate`, `AfterUpdate`
|
||||
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
* `Handler`: Access to handler, database, and registry
|
||||
* `Schema`, `Entity`, `TableName`: Request info
|
||||
* `Model`: The registered model type
|
||||
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||
* `ID`: Record ID (for single-record operations)
|
||||
* `Data`: Request data (for create/update)
|
||||
* `Result`: Operation result (for after hooks)
|
||||
* `Writer`: Response writer (allows hooks to modify response)
|
||||
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
|
||||
@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
modelColumns []string, // optional: for validation
|
||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||
) (string, error) {
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
@@ -62,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
@@ -91,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +135,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
|
||||
@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
// Should include full schema-qualified name in FROM clause
|
||||
if !strings.Contains(filter, "public.users") {
|
||||
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "fn.sortorder", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "8975"
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should compare fn.sortorder values
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should NOT contain empty comparison like "< "
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > posts.priority",
|
||||
|
||||
@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
// Add request-scoped data to context (including options)
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||
|
||||
// Derive operation for auth check
|
||||
var operation string
|
||||
switch method {
|
||||
case "GET":
|
||||
operation = "read"
|
||||
case "POST":
|
||||
operation = "create"
|
||||
case "PUT", "PATCH":
|
||||
operation = "update"
|
||||
case "DELETE":
|
||||
operation = "delete"
|
||||
default:
|
||||
operation = "read"
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "GET":
|
||||
if id != "" {
|
||||
@@ -688,12 +723,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
if len(options.Expand) > 0 {
|
||||
expandJoins = make(map[string]string)
|
||||
// TODO: Build actual JOIN SQL for each expand relation
|
||||
// For now, pass empty map as joins are handled via Preload
|
||||
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||
expandJoins := make(map[string]string)
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
alias := extractJoinAlias(joinClause)
|
||||
if alias != "" {
|
||||
expandJoins[alias] = joinClause
|
||||
}
|
||||
}
|
||||
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL
|
||||
@@ -1498,8 +1540,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1572,8 +1614,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1630,8 +1672,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -2111,7 +2153,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
// Column is already cast to TEXT if needed
|
||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||
case "in":
|
||||
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
if cond == "" {
|
||||
return query
|
||||
}
|
||||
return applyWhere(cond, inArgs...)
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
@@ -2187,24 +2233,25 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
|
||||
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||
switch strings.ToLower(filter.Operator) {
|
||||
case "eq", "equals":
|
||||
case "eq", "equals", "=":
|
||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "neq", "not_equals", "ne":
|
||||
case "neq", "not_equals", "ne", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gt", "greater_than":
|
||||
case "gt", "greater_than", ">":
|
||||
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gte", "greater_than_equals", "ge":
|
||||
case "gte", "greater_than_equals", "ge", ">=":
|
||||
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lt", "less_than":
|
||||
case "lt", "less_than", "<":
|
||||
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lte", "less_than_equals", "le":
|
||||
case "lte", "less_than_equals", "le", "<=":
|
||||
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "in":
|
||||
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
return cond, inArgs
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
@@ -2839,6 +2886,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
|
||||
// Filter base RequestOptions
|
||||
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
|
||||
filtered.RequestOptions.JoinAliases = options.JoinAliases
|
||||
|
||||
// Filter SearchColumns
|
||||
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||
|
||||
@@ -274,9 +274,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve relation names (convert table names to field names) if model is provided
|
||||
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
|
||||
if model != nil && !options.XFilesPresent {
|
||||
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
|
||||
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
|
||||
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
|
||||
// disambiguate when multiple fields point to the same related type.
|
||||
if model != nil {
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
|
||||
@@ -550,10 +552,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||
// - "JOIN roles r ON ..." -> "r"
|
||||
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||
func extractJoinAlias(joinClause string) string {
|
||||
// Pattern: JOIN table_name [AS] alias ON ...
|
||||
// We need to extract the alias (word before ON)
|
||||
|
||||
upperJoin := strings.ToUpper(joinClause)
|
||||
|
||||
// Find the "JOIN" keyword position
|
||||
@@ -562,7 +562,20 @@ func extractJoinAlias(joinClause string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the "ON" keyword position
|
||||
// Lateral joins: alias is the word after the closing ) and before ON
|
||||
if strings.Contains(upperJoin, "LATERAL") {
|
||||
lastClose := strings.LastIndex(joinClause, ")")
|
||||
if lastClose != -1 {
|
||||
words := strings.Fields(joinClause[lastClose+1:])
|
||||
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||
return words[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||
onIdx := strings.Index(upperJoin, " ON ")
|
||||
if onIdx == -1 {
|
||||
return ""
|
||||
@@ -863,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
||||
|
||||
// Resolve each part of the path
|
||||
currentModel := model
|
||||
for _, part := range parts {
|
||||
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||
for partIdx, part := range parts {
|
||||
isLast := partIdx == len(parts)-1
|
||||
var resolvedPart string
|
||||
if isLast {
|
||||
// For the final part, use join-key-aware resolution to disambiguate when
|
||||
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
|
||||
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
|
||||
localKey := preload.RelatedKey
|
||||
if localKey == "" {
|
||||
localKey = preload.ForeignKey
|
||||
}
|
||||
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
|
||||
} else {
|
||||
resolvedPart = h.resolveRelationName(currentModel, part)
|
||||
}
|
||||
resolvedParts = append(resolvedParts, resolvedPart)
|
||||
|
||||
// Try to get the model type for the next level
|
||||
@@ -980,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
|
||||
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
|
||||
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
|
||||
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
|
||||
if localKey == "" {
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// If it's already a direct field name, return as-is (no ambiguity).
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if modelType.Field(i).Name == nameOrTable {
|
||||
return nameOrTable
|
||||
}
|
||||
}
|
||||
|
||||
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||
localKeyLower := strings.ToLower(localKey)
|
||||
|
||||
// Find all fields whose related type matches nameOrTable, then pick the one
|
||||
// whose bun join tag local key matches localKey.
|
||||
var fallbackField string
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
normalizedTypeName := strings.ToLower(targetType.Name())
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||
if normalizedTypeName != normalizedInput {
|
||||
continue
|
||||
}
|
||||
|
||||
// Type name matches; record as fallback.
|
||||
if fallbackField == "" {
|
||||
fallbackField = field.Name
|
||||
}
|
||||
|
||||
// Check bun join tag: "join:localKey=foreignKey"
|
||||
bunTag := field.Tag.Get("bun")
|
||||
for _, tagPart := range strings.Split(bunTag, ",") {
|
||||
tagPart = strings.TrimSpace(tagPart)
|
||||
if !strings.HasPrefix(tagPart, "join:") {
|
||||
continue
|
||||
}
|
||||
joinSpec := strings.TrimPrefix(tagPart, "join:")
|
||||
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
|
||||
joinCols := strings.Fields(joinSpec)
|
||||
if len(joinCols) == 0 {
|
||||
joinCols = []string{joinSpec}
|
||||
}
|
||||
for _, joinCol := range joinCols {
|
||||
eqIdx := strings.Index(joinCol, "=")
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
|
||||
if joinLocalKey == localKeyLower {
|
||||
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
|
||||
return field.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackField != "" {
|
||||
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
|
||||
return fallbackField
|
||||
}
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||
// and recursively processes its children
|
||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||
@@ -1061,15 +1182,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
}
|
||||
}
|
||||
|
||||
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
|
||||
if len(xfile.SqlJoins) > 0 {
|
||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||
|
||||
for _, joinClause := range xfile.SqlJoins {
|
||||
// Sanitize the join clause
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||
continue
|
||||
}
|
||||
|
||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||
|
||||
// Extract join alias for validation
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||
}
|
||||
|
||||
// Add WHERE clause if SQL conditions specified
|
||||
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
|
||||
whereConditions := make([]string, 0)
|
||||
if len(xfile.SqlAnd) > 0 {
|
||||
// Process each SQL condition
|
||||
// Note: We don't add table prefixes here because they're only needed for JOINs
|
||||
// The handler will add prefixes later if SqlJoins are present
|
||||
var sqlAndOpts *common.RequestOptions
|
||||
if len(preloadOpt.JoinAliases) > 0 {
|
||||
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
|
||||
}
|
||||
for _, sqlCond := range xfile.SqlAnd {
|
||||
// Sanitize the condition without adding prefixes
|
||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
|
||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
|
||||
if sanitizedCond != "" {
|
||||
whereConditions = append(whereConditions, sanitizedCond)
|
||||
}
|
||||
@@ -1114,32 +1262,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Transfer SqlJoins from XFiles to PreloadOption
|
||||
if len(xfile.SqlJoins) > 0 {
|
||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||
|
||||
for _, joinClause := range xfile.SqlJoins {
|
||||
// Sanitize the join clause
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||
continue
|
||||
}
|
||||
|
||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||
|
||||
// Extract join alias for validation
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||
}
|
||||
|
||||
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||
// and store the recursive child's RelatedKey for recursion generation
|
||||
hasRecursiveChild := false
|
||||
|
||||
@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with alias",
|
||||
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with multiline subquery containing inner ON",
|
||||
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -42,6 +46,9 @@ type HookContext struct {
|
||||
Model interface{}
|
||||
Options ExtendedRequestOptions
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
@@ -56,6 +63,14 @@ type HookContext struct {
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
|
||||
// Request - the original HTTP request
|
||||
Request common.Request
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// logger.Debug("All hooks for %s executed successfully", hookType)
|
||||
|
||||
@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
entityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
||||
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
||||
entityHandler = authMiddleware(entityHandler)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler)
|
||||
metadataHandler = authMiddleware(metadataHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -280,9 +280,34 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -292,6 +317,14 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -313,7 +346,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// GET and POST for /{schema}/{entity}
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -324,9 +357,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -337,10 +371,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -352,9 +387,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -366,9 +402,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -380,9 +417,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -394,9 +432,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -408,10 +447,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// Metadata endpoint
|
||||
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
@@ -422,9 +462,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
@@ -441,6 +483,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
@@ -466,8 +509,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
@@ -487,7 +530,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup RestHeadSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
@@ -9,6 +10,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for restheadspec handler")
|
||||
}
|
||||
|
||||
|
||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
// Add to blacklist
|
||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
||||
"token": req.Token,
|
||||
"user_id": req.UserID,
|
||||
}).Error
|
||||
// Invalidate session via stored procedure
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
@@ -405,11 +402,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext to context)
|
||||
NewOptionalAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext or guest context; never 401)
|
||||
SetSecurityMiddleware → adds SecurityList to context
|
||||
↓
|
||||
Handler.Handle()
|
||||
Handler.Handle() → resolves model
|
||||
↓
|
||||
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
|
||||
├─ SecurityDisabled → allow
|
||||
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
└─ UserID == 0 → abort 401
|
||||
↓
|
||||
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
||||
↓
|
||||
@@ -693,15 +695,30 @@ http.Handle("/api/protected", authHandler)
|
||||
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||
http.Handle("/home", optionalHandler)
|
||||
|
||||
// Example handler
|
||||
func myHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
if userCtx.UserID == 0 {
|
||||
// Guest user
|
||||
} else {
|
||||
// Authenticated user
|
||||
}
|
||||
}
|
||||
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
```go
|
||||
// Register model with rules (pkg/modelregistry)
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // skip all auth when true
|
||||
CanPublicRead: true, // unauthenticated reads allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated can update
|
||||
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
|
||||
})
|
||||
|
||||
// CheckModelAuthAllowed used automatically in BeforeHandle hook
|
||||
// No code needed — call RegisterSecurityHooks and it's applied
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
@@ -12,6 +12,7 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
- ✅ **Testable** - Easy to mock and test
|
||||
- ✅ **Extensible** - Implement custom providers for your needs
|
||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||
|
||||
## Stored Procedure Architecture
|
||||
|
||||
@@ -38,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
||||
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
||||
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
||||
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||
|
||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||
|
||||
@@ -751,14 +758,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware (security package)
|
||||
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
|
||||
├─ Calls provider.Authenticate(request)
|
||||
└─ Adds UserContext to context
|
||||
├─ On success: adds authenticated UserContext to context
|
||||
└─ On failure: adds guest UserContext (UserID=0) to context
|
||||
↓
|
||||
SetSecurityMiddleware (security package)
|
||||
└─ Adds SecurityList to context
|
||||
↓
|
||||
Spec Handler (restheadspec/funcspec/resolvespec)
|
||||
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
|
||||
└─ Resolves schema + entity + model from request
|
||||
↓
|
||||
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
|
||||
│ ├─ Loads model rules from context or registry
|
||||
│ ├─ SecurityDisabled → allow
|
||||
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
│ └─ UserID == 0 → 401 unauthorized
|
||||
└─ On error: aborts with 401
|
||||
↓
|
||||
BeforeRead Hook (registered by spec)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
@@ -784,7 +802,8 @@ HTTP Response (secured data)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Security package is spec-agnostic and provides core logic
|
||||
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
|
||||
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
|
||||
- Each spec registers its own hooks that adapt to SecurityContext
|
||||
- Security rules are loaded once and cached for the request
|
||||
- Row security is applied to the query (database level)
|
||||
@@ -885,6 +904,155 @@ securityList := security.NewSecurityList(provider)
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||
```
|
||||
|
||||
## OAuth2 Authorization Server
|
||||
|
||||
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | RFC |
|
||||
|--------|------|-----|
|
||||
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
|
||||
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
|
||||
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
|
||||
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
|
||||
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
|
||||
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
|
||||
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
|
||||
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
|
||||
|
||||
### Config
|
||||
|
||||
```go
|
||||
cfg := security.OAuthServerConfig{
|
||||
Issuer: "https://example.com", // Required — token issuer URL
|
||||
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
|
||||
LoginTitle: "My App Login", // HTML login page title
|
||||
PersistClients: true, // Store clients in DB (multi-instance safe)
|
||||
PersistCodes: true, // Store codes in DB (multi-instance safe)
|
||||
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
|
||||
AccessTokenTTL: time.Hour,
|
||||
AuthCodeTTL: 5 * time.Minute,
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Default | Notes |
|
||||
|-------|---------|-------|
|
||||
| `Issuer` | — | Required |
|
||||
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
|
||||
| `LoginTitle` | `"Login"` | |
|
||||
| `PersistClients` | `false` | Set `true` for multi-instance |
|
||||
| `PersistCodes` | `false` | Set `true` for multi-instance |
|
||||
| `DefaultScopes` | `nil` | |
|
||||
| `AccessTokenTTL` | `1h` | |
|
||||
| `AuthCodeTTL` | `5m` | |
|
||||
|
||||
### Operating Modes
|
||||
|
||||
**Mode 1 — Direct login (username/password form)**
|
||||
|
||||
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
```
|
||||
|
||||
**Mode 2 — External provider federation**
|
||||
|
||||
Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, nil)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
srv.RegisterExternalProvider(githubAuth, "github")
|
||||
```
|
||||
|
||||
**Mode 3 — Both**
|
||||
|
||||
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
```
|
||||
|
||||
### Standalone Usage
|
||||
|
||||
```go
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/.well-known/", srv.HTTPHandler())
|
||||
mux.Handle("/oauth/", srv.HTTPHandler())
|
||||
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
|
||||
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
### DB Persistence
|
||||
|
||||
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
|
||||
|
||||
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
|
||||
|
||||
#### New DB Types
|
||||
|
||||
```go
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
#### DatabaseAuthenticator OAuth Methods
|
||||
|
||||
```go
|
||||
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
|
||||
auth.OAuthGetClient(ctx, clientID) // retrieve client
|
||||
auth.OAuthSaveCode(ctx, code) // persist authorization code
|
||||
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
|
||||
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
|
||||
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
|
||||
```
|
||||
|
||||
#### SQLNames Fields
|
||||
|
||||
```go
|
||||
type SQLNames struct {
|
||||
// ... existing fields ...
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
```
|
||||
|
||||
The main changes:
|
||||
1. Security package no longer knows about specific spec types
|
||||
2. Each spec registers its own security hooks
|
||||
@@ -1002,15 +1170,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
||||
}
|
||||
```
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
|
||||
|
||||
```go
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // true = skip all auth checks
|
||||
CanPublicRead: true, // unauthenticated GET allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated users can update
|
||||
CanDelete: false, // authenticated users cannot delete
|
||||
})
|
||||
```
|
||||
|
||||
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
|
||||
1. `SecurityDisabled` → allow all
|
||||
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
|
||||
3. Guest (UserID == 0) → return 401
|
||||
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
|
||||
|
||||
---
|
||||
|
||||
## Middleware and Handler API
|
||||
|
||||
### NewAuthMiddleware
|
||||
Standard middleware that authenticates all requests:
|
||||
Standard middleware that authenticates all requests and returns 401 on failure:
|
||||
|
||||
```go
|
||||
router.Use(security.NewAuthMiddleware(securityList))
|
||||
```
|
||||
|
||||
### NewOptionalAuthMiddleware
|
||||
Middleware for spec routes — always continues; sets guest context on auth failure:
|
||||
|
||||
```go
|
||||
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
|
||||
```
|
||||
|
||||
Routes can skip authentication using the `SkipAuth` helper:
|
||||
|
||||
```go
|
||||
|
||||
@@ -1397,3 +1397,173 @@ $$ LANGUAGE plpgsql;
|
||||
|
||||
-- Get credentials by username
|
||||
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Tables (OAuthServer persistence)
|
||||
-- ============================================
|
||||
|
||||
-- oauth_clients: persistent RFC 7591 registered clients
|
||||
CREATE TABLE IF NOT EXISTS oauth_clients (
|
||||
id SERIAL PRIMARY KEY,
|
||||
client_id VARCHAR(255) NOT NULL UNIQUE,
|
||||
redirect_uris TEXT[] NOT NULL,
|
||||
client_name VARCHAR(255),
|
||||
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
|
||||
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
|
||||
CREATE TABLE IF NOT EXISTS oauth_codes (
|
||||
id SERIAL PRIMARY KEY,
|
||||
code VARCHAR(255) NOT NULL UNIQUE,
|
||||
client_id VARCHAR(255) NOT NULL REFERENCES oauth_clients(client_id) ON DELETE CASCADE,
|
||||
redirect_uri TEXT NOT NULL,
|
||||
client_state TEXT,
|
||||
code_challenge VARCHAR(255) NOT NULL,
|
||||
code_challenge_method VARCHAR(10) DEFAULT 'S256',
|
||||
session_token TEXT NOT NULL,
|
||||
scopes TEXT[],
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
|
||||
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
|
||||
|
||||
-- ============================================
|
||||
-- OAuth2 Server Stored Procedures
|
||||
-- ============================================
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_client_id text;
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
v_client_id := p_data->>'client_id';
|
||||
|
||||
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
|
||||
VALUES (
|
||||
v_client_id,
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
|
||||
p_data->>'client_name',
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
|
||||
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
|
||||
)
|
||||
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
|
||||
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT to_jsonb(oauth_clients.*)
|
||||
INTO v_row
|
||||
FROM oauth_clients
|
||||
WHERE client_id = p_client_id AND is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, scopes, expires_at)
|
||||
VALUES (
|
||||
p_data->>'code',
|
||||
p_data->>'client_id',
|
||||
p_data->>'redirect_uri',
|
||||
p_data->>'client_state',
|
||||
p_data->>'code_challenge',
|
||||
COALESCE(p_data->>'code_challenge_method', 'S256'),
|
||||
p_data->>'session_token',
|
||||
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
|
||||
(p_data->>'expires_at')::timestamp
|
||||
);
|
||||
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
EXCEPTION WHEN OTHERS THEN
|
||||
RETURN QUERY SELECT false, SQLERRM;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
DELETE FROM oauth_codes
|
||||
WHERE code = p_code AND expires_at > now()
|
||||
RETURNING jsonb_build_object(
|
||||
'client_id', client_id,
|
||||
'redirect_uri', redirect_uri,
|
||||
'client_state', client_state,
|
||||
'code_challenge', code_challenge,
|
||||
'code_challenge_method', code_challenge_method,
|
||||
'session_token', session_token,
|
||||
'scopes', to_jsonb(scopes)
|
||||
) INTO v_row;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||
LANGUAGE plpgsql AS $$
|
||||
DECLARE
|
||||
v_row jsonb;
|
||||
BEGIN
|
||||
SELECT jsonb_build_object(
|
||||
'active', true,
|
||||
'sub', u.id::text,
|
||||
'username', u.username,
|
||||
'email', u.email,
|
||||
'user_level', u.user_level,
|
||||
'roles', to_jsonb(string_to_array(COALESCE(u.roles, ''), ',')),
|
||||
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
|
||||
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
|
||||
)
|
||||
INTO v_row
|
||||
FROM user_sessions s
|
||||
JOIN users u ON u.id = s.user_id
|
||||
WHERE s.session_token = p_token
|
||||
AND s.expires_at > now()
|
||||
AND u.is_active = true;
|
||||
|
||||
IF v_row IS NULL THEN
|
||||
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
|
||||
ELSE
|
||||
RETURN QUERY SELECT true, null::text, v_row;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
|
||||
RETURNS TABLE(p_success bool, p_error text)
|
||||
LANGUAGE plpgsql AS $$
|
||||
BEGIN
|
||||
DELETE FROM user_sessions WHERE session_token = p_token;
|
||||
RETURN QUERY SELECT true, null::text;
|
||||
END;
|
||||
$$;
|
||||
|
||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// For JWT, logout could involve token blacklisting
|
||||
// Add token to blacklist table
|
||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
||||
// "token": req.Token,
|
||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
||||
// }).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
||||
return applyColumnSecurity(secCtx, securityList)
|
||||
}
|
||||
|
||||
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanUpdate {
|
||||
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanDelete {
|
||||
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
|
||||
// model rules and the current user's authentication state. It is intended for use in
|
||||
// a BeforeHandle hook, fired after model resolution.
|
||||
//
|
||||
// Logic:
|
||||
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
|
||||
// 2. SecurityDisabled → allow.
|
||||
// 3. operation == "read" && CanPublicRead → allow.
|
||||
// 4. operation == "create" && CanPublicCreate → allow.
|
||||
// 5. operation == "update" && CanPublicUpdate → allow.
|
||||
// 6. operation == "delete" && CanPublicDelete → allow.
|
||||
// 7. Guest (UserID == 0) → return "authentication required".
|
||||
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
|
||||
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
// Model not registered - fall through to auth check
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
return nil
|
||||
}
|
||||
if operation == "read" && rules.CanPublicRead {
|
||||
return nil
|
||||
}
|
||||
if operation == "create" && rules.CanPublicCreate {
|
||||
return nil
|
||||
}
|
||||
if operation == "update" && rules.CanPublicUpdate {
|
||||
return nil
|
||||
}
|
||||
if operation == "delete" && rules.CanPublicDelete {
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
|
||||
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
return checkModelUpdateAllowed(secCtx)
|
||||
}
|
||||
|
||||
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
|
||||
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
return checkModelDeleteAllowed(secCtx)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
@@ -23,6 +25,7 @@ const (
|
||||
UserMetaKey contextKey = "user_meta"
|
||||
SkipAuthKey contextKey = "skip_auth"
|
||||
OptionalAuthKey contextKey = "optional_auth"
|
||||
ModelRulesKey contextKey = "model_rules"
|
||||
)
|
||||
|
||||
// SkipAuth returns a context with skip auth flag set to true
|
||||
@@ -136,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
|
||||
})
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
|
||||
// On auth failure, a guest user context is set instead of returning 401.
|
||||
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
|
||||
// after model resolution.
|
||||
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates an authentication middleware with the given security list
|
||||
// This middleware extracts user authentication from the request and adds it to context
|
||||
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
||||
@@ -182,6 +210,68 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
|
||||
}
|
||||
}
|
||||
|
||||
// NewModelAuthMiddleware creates authentication middleware that respects ModelRules for the given model name.
|
||||
// It first checks if ModelRules are set for the model:
|
||||
// - If SecurityDisabled is true, authentication is skipped and a guest context is set.
|
||||
// - Otherwise, all checks from NewAuthMiddleware apply (SkipAuthKey, provider check, OptionalAuthKey, Authenticate).
|
||||
//
|
||||
// If the model is not found in any registry, the middleware falls back to standard NewAuthMiddleware behaviour.
|
||||
func NewModelAuthMiddleware(securityList *SecurityList, modelName string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check ModelRules first
|
||||
if rules, err := modelregistry.GetModelRulesByName(modelName); err == nil {
|
||||
// Store rules in context for downstream use (e.g., security hooks)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ModelRulesKey, rules))
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
isRead := r.Method == http.MethodGet || r.Method == http.MethodHead
|
||||
isUpdate := r.Method == http.MethodPut || r.Method == http.MethodPatch
|
||||
if (isRead && rules.CanPublicRead) || (isUpdate && rules.CanPublicUpdate) {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this route should skip authentication
|
||||
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
// Get the security provider
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this route has optional authentication
|
||||
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
|
||||
|
||||
// Try to authenticate
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
if optional {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecurityMiddleware adds security context to requests
|
||||
// This middleware should be applied after AuthMiddleware
|
||||
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
@@ -366,6 +456,131 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
|
||||
return meta, ok
|
||||
}
|
||||
|
||||
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
|
||||
// All fields are optional; sensible secure defaults are applied when omitted.
|
||||
type SessionCookieOptions struct {
|
||||
// Name is the cookie name. Defaults to "session_token".
|
||||
Name string
|
||||
// Path is the cookie path. Defaults to "/".
|
||||
Path string
|
||||
// Domain restricts the cookie to a specific domain. Empty means current host.
|
||||
Domain string
|
||||
// Secure sets the Secure flag. Defaults to true.
|
||||
// Set to false only in local development over HTTP.
|
||||
Secure *bool
|
||||
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) name() string {
|
||||
if o.Name != "" {
|
||||
return o.Name
|
||||
}
|
||||
return "session_token"
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) path() string {
|
||||
if o.Path != "" {
|
||||
return o.Path
|
||||
}
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) secure() bool {
|
||||
if o.Secure != nil {
|
||||
return *o.Secure
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) sameSite() http.SameSite {
|
||||
if o.SameSite != 0 {
|
||||
return o.SameSite
|
||||
}
|
||||
return http.SameSiteLaxMode
|
||||
}
|
||||
|
||||
// SetSessionCookie writes the session_token cookie to the response after a successful login.
|
||||
// Call this immediately after a successful Authenticator.Login() call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp, err := auth.Login(r.Context(), req)
|
||||
// if err != nil { ... }
|
||||
// security.SetSessionCookie(w, resp)
|
||||
// json.NewEncoder(w).Encode(resp)
|
||||
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
|
||||
maxAge := 0
|
||||
if loginResp.ExpiresIn > 0 {
|
||||
maxAge = int(loginResp.ExpiresIn)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: o.name(),
|
||||
Value: loginResp.Token,
|
||||
Path: o.path(),
|
||||
Domain: o.Domain,
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: o.secure(),
|
||||
SameSite: o.sameSite(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// token := security.GetSessionCookie(r)
|
||||
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
cookie, err := r.Cookie(o.name())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
|
||||
// Call this after a successful Authenticator.Logout() call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := auth.Logout(r.Context(), req)
|
||||
// if err != nil { ... }
|
||||
// security.ClearSessionCookie(w)
|
||||
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: o.name(),
|
||||
Value: "",
|
||||
Path: o.path(),
|
||||
Domain: o.Domain,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: o.secure(),
|
||||
SameSite: o.sameSite(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
|
||||
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
|
||||
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)
|
||||
return rules, ok
|
||||
}
|
||||
|
||||
// // Handler adapters for resolvespec/restheadspec compatibility
|
||||
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
||||
|
||||
|
||||
@@ -244,10 +244,10 @@ func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userC
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM resolvespec_oauth_getorcreateuser($1::jsonb)
|
||||
`, userJSON).Scan(&success, &errMsg, &userID)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
@@ -287,10 +287,10 @@ func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, session
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_createsession($1::jsonb)
|
||||
`, sessionJSON).Scan(&success, &errMsg)
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
@@ -385,10 +385,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getrefreshtoken($1)
|
||||
`, refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
@@ -451,10 +451,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM resolvespec_oauth_updaterefreshtoken($1::jsonb)
|
||||
`, updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||
@@ -472,10 +472,10 @@ func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshT
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, `
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM resolvespec_oauth_getuser($1)
|
||||
`, session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
|
||||
859
pkg/security/oauth_server.go
Normal file
859
pkg/security/oauth_server.go
Normal file
@@ -0,0 +1,859 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
|
||||
type OAuthServerConfig struct {
|
||||
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
|
||||
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
|
||||
Issuer string
|
||||
|
||||
// ProviderCallbackPath is the path on this server that external OAuth2 providers
|
||||
// redirect back to. Defaults to "/oauth/provider/callback".
|
||||
ProviderCallbackPath string
|
||||
|
||||
// LoginTitle is shown on the built-in login form when the server acts as its own
|
||||
// identity provider. Defaults to "MCP Login".
|
||||
LoginTitle string
|
||||
|
||||
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
|
||||
// Clients registered during a session survive server restarts.
|
||||
PersistClients bool
|
||||
|
||||
// PersistCodes stores authorization codes in the database.
|
||||
// Useful for multi-instance deployments. Defaults to in-memory.
|
||||
PersistCodes bool
|
||||
|
||||
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
|
||||
DefaultScopes []string
|
||||
|
||||
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
|
||||
AccessTokenTTL time.Duration
|
||||
|
||||
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
|
||||
AuthCodeTTL time.Duration
|
||||
}
|
||||
|
||||
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
|
||||
type oauthClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// pendingAuth tracks an in-progress authorization code exchange.
|
||||
type pendingAuth struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
ClientState string
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
ProviderName string // empty = password login
|
||||
ExpiresAt time.Time
|
||||
SessionToken string // set after authentication completes
|
||||
Scopes []string // requested scopes
|
||||
}
|
||||
|
||||
// externalProvider pairs a DatabaseAuthenticator with its provider name.
|
||||
type externalProvider struct {
|
||||
auth *DatabaseAuthenticator
|
||||
providerName string
|
||||
}
|
||||
|
||||
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
|
||||
//
|
||||
// It can act as both:
|
||||
// - A direct identity provider using DatabaseAuthenticator username/password login
|
||||
// - A federation layer that delegates authentication to external OAuth2 providers
|
||||
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
|
||||
//
|
||||
// The server exposes these RFC-compliant endpoints:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
|
||||
// POST /oauth/authorize Direct login form submission
|
||||
// POST /oauth/token Token exchange and refresh
|
||||
// POST /oauth/revoke RFC 7009 — token revocation
|
||||
// POST /oauth/introspect RFC 7662 — token introspection
|
||||
// GET {ProviderCallbackPath} Internal — external provider callback
|
||||
type OAuthServer struct {
|
||||
cfg OAuthServerConfig
|
||||
auth *DatabaseAuthenticator // nil = only external providers
|
||||
providers []externalProvider
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[string]*oauthClient
|
||||
pending map[string]*pendingAuth // provider_state → pending (external flow)
|
||||
codes map[string]*pendingAuth // auth_code → pending (post-auth)
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new MCP OAuth2 authorization server.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
|
||||
// acts as its own identity provider). Pass nil to use only external providers.
|
||||
// External providers are added separately via RegisterExternalProvider.
|
||||
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
|
||||
if cfg.ProviderCallbackPath == "" {
|
||||
cfg.ProviderCallbackPath = "/oauth/provider/callback"
|
||||
}
|
||||
if cfg.LoginTitle == "" {
|
||||
cfg.LoginTitle = "Sign in"
|
||||
}
|
||||
if len(cfg.DefaultScopes) == 0 {
|
||||
cfg.DefaultScopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
if cfg.AccessTokenTTL == 0 {
|
||||
cfg.AccessTokenTTL = 24 * time.Hour
|
||||
}
|
||||
if cfg.AuthCodeTTL == 0 {
|
||||
cfg.AuthCodeTTL = 2 * time.Minute
|
||||
}
|
||||
s := &OAuthServer{
|
||||
cfg: cfg,
|
||||
auth: auth,
|
||||
clients: make(map[string]*oauthClient),
|
||||
pending: make(map[string]*pendingAuth),
|
||||
codes: make(map[string]*pendingAuth),
|
||||
}
|
||||
go s.cleanupExpired()
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
|
||||
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
|
||||
// configured with WithOAuth2(providerName, ...) before calling this.
|
||||
// Multiple providers can be registered; the first is used as the default.
|
||||
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
|
||||
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
|
||||
}
|
||||
|
||||
// ProviderCallbackPath returns the configured path for external provider callbacks.
|
||||
func (s *OAuthServer) ProviderCallbackPath() string {
|
||||
return s.cfg.ProviderCallbackPath
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
|
||||
// Mount it at the root of your HTTP server alongside the MCP transport.
|
||||
//
|
||||
// mux := http.NewServeMux()
|
||||
// mux.Handle("/", oauthServer.HTTPHandler())
|
||||
// mux.Handle("/mcp/", mcpTransport)
|
||||
func (s *OAuthServer) HTTPHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
|
||||
mux.HandleFunc("/oauth/register", s.registerHandler)
|
||||
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
|
||||
mux.HandleFunc("/oauth/token", s.tokenHandler)
|
||||
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
|
||||
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
|
||||
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
|
||||
return mux
|
||||
}
|
||||
|
||||
// cleanupExpired removes stale pending auths and codes every 5 minutes.
|
||||
func (s *OAuthServer) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
}
|
||||
}
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 8414 — Server metadata
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
issuer := s.cfg.Issuer
|
||||
meta := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": issuer + "/oauth/authorize",
|
||||
"token_endpoint": issuer + "/oauth/token",
|
||||
"registration_endpoint": issuer + "/oauth/register",
|
||||
"revocation_endpoint": issuer + "/oauth/revoke",
|
||||
"introspection_endpoint": issuer + "/oauth/introspect",
|
||||
"scopes_supported": s.cfg.DefaultScopes,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(meta) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7591 — Dynamic client registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
allowedScopes := req.AllowedScopes
|
||||
if len(allowedScopes) == 0 {
|
||||
allowedScopes = s.cfg.DefaultScopes
|
||||
}
|
||||
clientID, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
client := &oauthClient{
|
||||
ClientID: clientID,
|
||||
RedirectURIs: req.RedirectURIs,
|
||||
ClientName: req.ClientName,
|
||||
GrantTypes: grantTypes,
|
||||
AllowedScopes: allowedScopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistClients && s.auth != nil {
|
||||
dbClient := &OAuthServerClient{
|
||||
ClientID: client.ClientID,
|
||||
RedirectURIs: client.RedirectURIs,
|
||||
ClientName: client.ClientName,
|
||||
GrantTypes: client.GrantTypes,
|
||||
AllowedScopes: client.AllowedScopes,
|
||||
}
|
||||
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(client) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Authorization endpoint — GET + POST /oauth/authorize
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
s.authorizeGet(w, r)
|
||||
case http.MethodPost:
|
||||
s.authorizePost(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeGet validates the request and either:
|
||||
// - Redirects to an external provider (if providers are registered)
|
||||
// - Renders a login form (if the server is its own identity provider)
|
||||
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
clientState := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
providerName := q.Get("provider")
|
||||
scopeStr := q.Get("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
if q.Get("response_type") != "code" {
|
||||
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" {
|
||||
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
|
||||
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok {
|
||||
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// External provider path
|
||||
if len(s.providers) > 0 {
|
||||
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
|
||||
return
|
||||
}
|
||||
|
||||
// Direct login form path (server is its own identity provider)
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
|
||||
}
|
||||
|
||||
// authorizePost handles login form submission for the direct login flow.
|
||||
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := r.FormValue("client_id")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientState := r.FormValue("client_state")
|
||||
codeChallenge := r.FormValue("code_challenge")
|
||||
codeChallengeMethod := r.FormValue("code_challenge_method")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
scopeStr := r.FormValue("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
}
|
||||
|
||||
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
|
||||
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
var provider *externalProvider
|
||||
if providerName != "" {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == providerName {
|
||||
provider = &s.providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = &s.providers[0]
|
||||
}
|
||||
|
||||
providerState, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: provider.providerName,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Scopes: scopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.pending[providerState] = pending
|
||||
s.mu.Unlock()
|
||||
|
||||
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// External provider callback — GET {ProviderCallbackPath}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
providerState := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
pending, ok := s.pending[providerState]
|
||||
if ok {
|
||||
delete(s.pending, providerState)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider := s.providerByName(pending.ProviderName)
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token,
|
||||
pending.ClientID, pending.RedirectURI, pending.ClientState,
|
||||
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
|
||||
}
|
||||
|
||||
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
authCode, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: providerName,
|
||||
SessionToken: sessionToken,
|
||||
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode := &OAuthCode{
|
||||
Code: authCode,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
SessionToken: sessionToken,
|
||||
Scopes: scopes,
|
||||
ExpiresAt: pending.ExpiresAt,
|
||||
}
|
||||
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
s.codes[authCode] = pending
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
qp := redirectURL.Query()
|
||||
qp.Set("code", authCode)
|
||||
if clientState != "" {
|
||||
qp.Set("state", clientState)
|
||||
}
|
||||
redirectURL.RawQuery = qp.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Token endpoint — POST /oauth/token
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch r.FormValue("grant_type") {
|
||||
case "authorization_code":
|
||||
s.handleAuthCodeGrant(w, r)
|
||||
case "refresh_token":
|
||||
s.handleRefreshGrant(w, r)
|
||||
default:
|
||||
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("code")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientID := r.FormValue("client_id")
|
||||
codeVerifier := r.FormValue("code_verifier")
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
var scopes []string
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = oauthCode.SessionToken
|
||||
scopes = oauthCode.Scopes
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
pending, ok := s.codes[code]
|
||||
if ok {
|
||||
delete(s.codes, code)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = pending.SessionToken
|
||||
scopes = pending.Scopes
|
||||
}
|
||||
|
||||
writeOAuthToken(w, sessionToken, "", scopes)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
||||
refreshToken := r.FormValue("refresh_token")
|
||||
providerName := r.FormValue("provider")
|
||||
if refreshToken == "" {
|
||||
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Try external providers first, then fall back to DatabaseAuthenticator
|
||||
provider := s.providerByName(providerName)
|
||||
if provider != nil {
|
||||
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7009 — Token revocation
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
if token == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7662 — Token introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if token == "" || s.auth == nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.OAuthIntrospectToken(r.Context(), token)
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(info) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Login form (direct identity provider mode)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
errHTML := ""
|
||||
if errMsg != "" {
|
||||
errHTML = `<p style="color:red">` + errMsg + `</p>`
|
||||
}
|
||||
fmt.Fprintf(w, loginFormHTML,
|
||||
s.cfg.LoginTitle,
|
||||
s.cfg.LoginTitle,
|
||||
errHTML,
|
||||
clientID,
|
||||
htmlEscape(redirectURI),
|
||||
htmlEscape(clientState),
|
||||
htmlEscape(codeChallenge),
|
||||
htmlEscape(codeChallengeMethod),
|
||||
htmlEscape(scope),
|
||||
)
|
||||
}
|
||||
|
||||
const loginFormHTML = `<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>%s</title>
|
||||
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
|
||||
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
|
||||
h2{margin:0 0 1.5rem;font-size:1.25rem}
|
||||
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
|
||||
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
|
||||
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
|
||||
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
|
||||
</head><body><div class="card">
|
||||
<h2>%s</h2>%s
|
||||
<form method="POST" action="/oauth/authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="client_state" value="%s">
|
||||
<input type="hidden" name="code_challenge" value="%s">
|
||||
<input type="hidden" name="code_challenge_method" value="%s">
|
||||
<input type="hidden" name="scope" value="%s">
|
||||
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
|
||||
<label>Password</label><input type="password" name="password" autocomplete="current-password">
|
||||
<button type="submit">Sign in</button>
|
||||
</form></div></body></html>`
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
|
||||
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
|
||||
s.mu.RLock()
|
||||
c, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
if !s.cfg.PersistClients || s.auth == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c = &oauthClient{
|
||||
ClientID: dbClient.ClientID,
|
||||
RedirectURIs: dbClient.RedirectURIs,
|
||||
ClientName: dbClient.ClientName,
|
||||
GrantTypes: dbClient.GrantTypes,
|
||||
AllowedScopes: dbClient.AllowedScopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = c
|
||||
s.mu.Unlock()
|
||||
return c, true
|
||||
}
|
||||
|
||||
func (s *OAuthServer) providerByName(name string) *externalProvider {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == name {
|
||||
return &s.providers[i]
|
||||
}
|
||||
}
|
||||
// If name is empty and only one provider exists, return it
|
||||
if name == "" && len(s.providers) == 1 {
|
||||
return &s.providers[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePKCESHA256(challenge, verifier string) bool {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
|
||||
}
|
||||
|
||||
func randomOAuthToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func oauthSliceContains(slice []string, s string) bool {
|
||||
for _, v := range slice {
|
||||
if strings.EqualFold(v, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
if refreshToken != "" {
|
||||
resp["refresh_token"] = refreshToken
|
||||
}
|
||||
if len(scopes) > 0 {
|
||||
resp["scope"] = strings.Join(scopes, " ")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
|
||||
resp := map[string]string{"error": errCode}
|
||||
if description != "" {
|
||||
resp["error_description"] = description
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
return s
|
||||
}
|
||||
202
pkg/security/oauth_server_db.go
Normal file
202
pkg/security/oauth_server_db.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthCode is a short-lived authorization code.
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OAuthTokenInfo is the RFC 7662 token introspection response.
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthRegisterClient persists an OAuth2 client registration.
|
||||
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
|
||||
input, err := json.Marshal(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal client: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to register client")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registered client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthGetClient retrieves a registered client by ID.
|
||||
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("client not found")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthSaveCode persists an authorization code.
|
||||
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
|
||||
input, err := json.Marshal(code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal code: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to save code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
|
||||
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired code")
|
||||
}
|
||||
|
||||
var result OAuthCode
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse code data: %w", err)
|
||||
}
|
||||
result.Code = code
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
|
||||
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to introspect token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("introspection failed")
|
||||
}
|
||||
|
||||
var result OAuthTokenInfo
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token info: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
|
||||
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to revoke token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -11,12 +11,14 @@ import (
|
||||
)
|
||||
|
||||
// DatabasePasskeyProvider implements PasskeyProvider using database storage
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabasePasskeyProvider struct {
|
||||
db *sql.DB
|
||||
rpID string // Relying Party ID (domain)
|
||||
rpName string // Relying Party display name
|
||||
rpOrigin string // Expected origin for WebAuthn
|
||||
timeout int64 // Timeout in milliseconds (default: 60000)
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// DatabasePasskeyProviderOptions configures the passkey provider
|
||||
@@ -29,6 +31,8 @@ type DatabasePasskeyProviderOptions struct {
|
||||
RPOrigin string
|
||||
// Timeout is the timeout for operations in milliseconds (default: 60000)
|
||||
Timeout int64
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabasePasskeyProvider creates a new database-backed passkey provider
|
||||
@@ -37,12 +41,15 @@ func NewDatabasePasskeyProvider(db *sql.DB, opts DatabasePasskeyProviderOptions)
|
||||
opts.Timeout = 60000 // 60 seconds default
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabasePasskeyProvider{
|
||||
db: db,
|
||||
rpID: opts.RPID,
|
||||
rpName: opts.RPName,
|
||||
rpOrigin: opts.RPOrigin,
|
||||
timeout: opts.Timeout,
|
||||
sqlNames: sqlNames,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,7 +139,7 @@ func (p *DatabasePasskeyProvider) CompleteRegistration(ctx context.Context, user
|
||||
var errorMsg sql.NullString
|
||||
var credentialID sql.NullInt64
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential_id FROM resolvespec_passkey_store_credential($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential_id FROM %s($1::jsonb)`, p.sqlNames.PasskeyStoreCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, string(credJSON)).Scan(&success, &errorMsg, &credentialID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to store credential: %w", err)
|
||||
@@ -173,7 +180,7 @@ func (p *DatabasePasskeyProvider) BeginAuthentication(ctx context.Context, usern
|
||||
var userID sql.NullInt64
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user_id, p_credentials::text FROM resolvespec_passkey_get_credentials_by_username($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user_id, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetCredsByUsername)
|
||||
err := p.db.QueryRowContext(ctx, query, username).Scan(&success, &errorMsg, &userID, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -233,7 +240,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var errorMsg sql.NullString
|
||||
var credentialJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credential::text FROM resolvespec_passkey_get_credential($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credential::text FROM %s($1)`, p.sqlNames.PasskeyGetCredential)
|
||||
err := p.db.QueryRowContext(ctx, query, response.RawID).Scan(&success, &errorMsg, &credentialJSON)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get credential: %w", err)
|
||||
@@ -264,7 +271,7 @@ func (p *DatabasePasskeyProvider) CompleteAuthentication(ctx context.Context, re
|
||||
var updateError sql.NullString
|
||||
var cloneWarning sql.NullBool
|
||||
|
||||
updateQuery := `SELECT p_success, p_error, p_clone_warning FROM resolvespec_passkey_update_counter($1, $2)`
|
||||
updateQuery := fmt.Sprintf(`SELECT p_success, p_error, p_clone_warning FROM %s($1, $2)`, p.sqlNames.PasskeyUpdateCounter)
|
||||
err = p.db.QueryRowContext(ctx, updateQuery, response.RawID, newCounter).Scan(&updateSuccess, &updateError, &cloneWarning)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to update counter: %w", err)
|
||||
@@ -283,7 +290,7 @@ func (p *DatabasePasskeyProvider) GetCredentials(ctx context.Context, userID int
|
||||
var errorMsg sql.NullString
|
||||
var credentialsJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_credentials::text FROM resolvespec_passkey_get_user_credentials($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_credentials::text FROM %s($1)`, p.sqlNames.PasskeyGetUserCredentials)
|
||||
err := p.db.QueryRowContext(ctx, query, userID).Scan(&success, &errorMsg, &credentialsJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get credentials: %w", err)
|
||||
@@ -362,7 +369,7 @@ func (p *DatabasePasskeyProvider) DeleteCredential(ctx context.Context, userID i
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_delete_credential($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, p.sqlNames.PasskeyDeleteCredential)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete credential: %w", err)
|
||||
@@ -388,7 +395,7 @@ func (p *DatabasePasskeyProvider) UpdateCredentialName(ctx context.Context, user
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_passkey_update_name($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3)`, p.sqlNames.PasskeyUpdateName)
|
||||
err = p.db.QueryRowContext(ctx, query, userID, credID, name).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update credential name: %w", err)
|
||||
|
||||
@@ -58,8 +58,7 @@ func (a *HeaderAuthenticator) Authenticate(r *http.Request) (*UserContext, error
|
||||
|
||||
// DatabaseAuthenticator provides session-based authentication with database storage
|
||||
// All database operations go through stored procedures for security and consistency
|
||||
// Requires stored procedures: resolvespec_login, resolvespec_logout, resolvespec_session,
|
||||
// resolvespec_session_update, resolvespec_refresh_token
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See database_schema.sql for procedure definitions
|
||||
// Also supports multiple OAuth2 providers configured with WithOAuth2()
|
||||
// Also supports passkey authentication configured with WithPasskey()
|
||||
@@ -67,6 +66,7 @@ type DatabaseAuthenticator struct {
|
||||
db *sql.DB
|
||||
cache *cache.Cache
|
||||
cacheTTL time.Duration
|
||||
sqlNames *SQLNames
|
||||
|
||||
// OAuth2 providers registry (multiple providers supported)
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
@@ -85,6 +85,9 @@ type DatabaseAuthenticatorOptions struct {
|
||||
Cache *cache.Cache
|
||||
// PasskeyProvider is an optional passkey provider for WebAuthn/FIDO2 authentication
|
||||
PasskeyProvider PasskeyProvider
|
||||
// SQLNames provides custom SQL procedure/function names. If nil, uses DefaultSQLNames().
|
||||
// Partial overrides are supported: only set the fields you want to change.
|
||||
SQLNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseAuthenticator(db *sql.DB) *DatabaseAuthenticator {
|
||||
@@ -103,10 +106,13 @@ func NewDatabaseAuthenticatorWithOptions(db *sql.DB, opts DatabaseAuthenticatorO
|
||||
cacheInstance = cache.GetDefaultCache()
|
||||
}
|
||||
|
||||
sqlNames := MergeSQLNames(DefaultSQLNames(), opts.SQLNames)
|
||||
|
||||
return &DatabaseAuthenticator{
|
||||
db: db,
|
||||
cache: cacheInstance,
|
||||
cacheTTL: opts.CacheTTL,
|
||||
sqlNames: sqlNames,
|
||||
passkeyProvider: opts.PasskeyProvider,
|
||||
}
|
||||
}
|
||||
@@ -118,12 +124,11 @@ func (a *DatabaseAuthenticator) Login(ctx context.Context, req LoginRequest) (*L
|
||||
return nil, fmt.Errorf("failed to marshal login request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_login($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Login)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -153,12 +158,11 @@ func (a *DatabaseAuthenticator) Register(ctx context.Context, req RegisterReques
|
||||
return nil, fmt.Errorf("failed to marshal register request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_register stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_register($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Register)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register query failed: %w", err)
|
||||
@@ -187,12 +191,11 @@ func (a *DatabaseAuthenticator) Logout(ctx context.Context, req LogoutRequest) e
|
||||
return fmt.Errorf("failed to marshal logout request: %w", err)
|
||||
}
|
||||
|
||||
// Call resolvespec_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_data::text FROM resolvespec_logout($1::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.Logout)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -222,9 +225,8 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
|
||||
if sessionToken == "" {
|
||||
// Try cookie
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err == nil {
|
||||
tokens = []string{cookie.Value}
|
||||
if token := GetSessionCookie(r); token != "" {
|
||||
tokens = []string{token}
|
||||
reference = "cookie"
|
||||
}
|
||||
} else {
|
||||
@@ -267,7 +269,7 @@ func (a *DatabaseAuthenticator) Authenticate(r *http.Request) (*UserContext, err
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(r.Context(), query, token, reference).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session query failed: %w", err)
|
||||
@@ -339,24 +341,22 @@ func (a *DatabaseAuthenticator) updateSessionActivity(ctx context.Context, sessi
|
||||
return
|
||||
}
|
||||
|
||||
// Call resolvespec_session_update stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var updatedUserJSON sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session_update($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.SessionUpdate)
|
||||
_ = a.db.QueryRowContext(ctx, query, sessionToken, string(userJSON)).Scan(&success, &errorMsg, &updatedUserJSON)
|
||||
}
|
||||
|
||||
// RefreshToken implements Refreshable interface
|
||||
func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||
// Call api_refresh_token stored procedure
|
||||
// First, we need to get the current user context for the refresh token
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON sql.NullString
|
||||
// Get current session to pass to refresh
|
||||
query := `SELECT p_success, p_error, p_user::text FROM resolvespec_session($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2)`, a.sqlNames.Session)
|
||||
err := a.db.QueryRowContext(ctx, query, refreshToken, "refresh").Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token query failed: %w", err)
|
||||
@@ -369,12 +369,11 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
return nil, fmt.Errorf("invalid refresh token")
|
||||
}
|
||||
|
||||
// Call resolvespec_refresh_token to generate new token
|
||||
var newSuccess bool
|
||||
var newErrorMsg sql.NullString
|
||||
var newUserJSON sql.NullString
|
||||
|
||||
refreshQuery := `SELECT p_success, p_error, p_user::text FROM resolvespec_refresh_token($1, $2::jsonb)`
|
||||
refreshQuery := fmt.Sprintf(`SELECT p_success, p_error, p_user::text FROM %s($1, $2::jsonb)`, a.sqlNames.RefreshToken)
|
||||
err = a.db.QueryRowContext(ctx, refreshQuery, refreshToken, userJSON).Scan(&newSuccess, &newErrorMsg, &newUserJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("refresh token generation failed: %w", err)
|
||||
@@ -402,27 +401,28 @@ func (a *DatabaseAuthenticator) RefreshToken(ctx context.Context, refreshToken s
|
||||
|
||||
// JWTAuthenticator provides JWT token-based authentication
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedures: resolvespec_jwt_login, resolvespec_jwt_logout
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// NOTE: JWT signing/verification requires github.com/golang-jwt/jwt/v5 to be installed and imported
|
||||
type JWTAuthenticator struct {
|
||||
secretKey []byte
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB) *JWTAuthenticator {
|
||||
func NewJWTAuthenticator(secretKey string, db *sql.DB, names ...*SQLNames) *JWTAuthenticator {
|
||||
return &JWTAuthenticator{
|
||||
secretKey: []byte(secretKey),
|
||||
db: db,
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||
// Call resolvespec_jwt_login stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var userJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_user FROM resolvespec_jwt_login($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_user FROM %s($1, $2)`, a.sqlNames.JWTLogin)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Username, req.Password).Scan(&success, &errorMsg, &userJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login query failed: %w", err)
|
||||
@@ -472,11 +472,10 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginR
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// Call resolvespec_jwt_logout stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_jwt_logout($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2)`, a.sqlNames.JWTLogout)
|
||||
err := a.db.QueryRowContext(ctx, query, req.Token, req.UserID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("logout query failed: %w", err)
|
||||
@@ -512,24 +511,24 @@ func (a *JWTAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||
|
||||
// DatabaseColumnSecurityProvider loads column security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_column_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseColumnSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db}
|
||||
func NewDatabaseColumnSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseColumnSecurityProvider {
|
||||
return &DatabaseColumnSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||
var rules []ColumnSecurity
|
||||
|
||||
// Call resolvespec_column_security stored procedure
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var rulesJSON []byte
|
||||
|
||||
query := `SELECT p_success, p_error, p_rules FROM resolvespec_column_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_rules FROM %s($1, $2, $3)`, p.sqlNames.ColumnSecurity)
|
||||
err := p.db.QueryRowContext(ctx, query, userID, schema, table).Scan(&success, &errorMsg, &rulesJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load column security: %w", err)
|
||||
@@ -577,21 +576,21 @@ func (p *DatabaseColumnSecurityProvider) GetColumnSecurity(ctx context.Context,
|
||||
|
||||
// DatabaseRowSecurityProvider loads row security from database
|
||||
// All database operations go through stored procedures
|
||||
// Requires stored procedure: resolvespec_row_security
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
type DatabaseRowSecurityProvider struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db}
|
||||
func NewDatabaseRowSecurityProvider(db *sql.DB, names ...*SQLNames) *DatabaseRowSecurityProvider {
|
||||
return &DatabaseRowSecurityProvider{db: db, sqlNames: resolveSQLNames(names...)}
|
||||
}
|
||||
|
||||
func (p *DatabaseRowSecurityProvider) GetRowSecurity(ctx context.Context, userID int, schema, table string) (RowSecurity, error) {
|
||||
var template string
|
||||
var hasBlock bool
|
||||
|
||||
// Call resolvespec_row_security stored procedure
|
||||
query := `SELECT p_template, p_block FROM resolvespec_row_security($1, $2, $3)`
|
||||
query := fmt.Sprintf(`SELECT p_template, p_block FROM %s($1, $2, $3)`, p.sqlNames.RowSecurity)
|
||||
|
||||
err := p.db.QueryRowContext(ctx, query, schema, table, userID).Scan(&template, &hasBlock)
|
||||
if err != nil {
|
||||
@@ -759,56 +758,47 @@ func (a *DatabaseAuthenticator) LoginWithPasskey(ctx context.Context, req Passke
|
||||
return nil, fmt.Errorf("passkey authentication failed: %w", err)
|
||||
}
|
||||
|
||||
// Get user data from database
|
||||
var username, email, roles string
|
||||
var userLevel int
|
||||
query := `SELECT username, email, user_level, COALESCE(roles, '') FROM users WHERE id = $1 AND is_active = true`
|
||||
err = a.db.QueryRowContext(ctx, query, userID).Scan(&username, &email, &userLevel, &roles)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
// Build request JSON for passkey login stored procedure
|
||||
reqData := map[string]any{
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
// Generate session token
|
||||
sessionToken := "sess_" + generateRandomString(32) + "_" + fmt.Sprintf("%d", time.Now().Unix())
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
|
||||
// Extract IP and user agent from claims
|
||||
ipAddress := ""
|
||||
userAgent := ""
|
||||
if req.Claims != nil {
|
||||
if ip, ok := req.Claims["ip_address"].(string); ok {
|
||||
ipAddress = ip
|
||||
reqData["ip_address"] = ip
|
||||
}
|
||||
if ua, ok := req.Claims["user_agent"].(string); ok {
|
||||
userAgent = ua
|
||||
reqData["user_agent"] = ua
|
||||
}
|
||||
}
|
||||
|
||||
// Create session
|
||||
insertQuery := `INSERT INTO user_sessions (session_token, user_id, expires_at, ip_address, user_agent, last_activity_at)
|
||||
VALUES ($1, $2, $3, $4, $5, now())`
|
||||
_, err = a.db.ExecContext(ctx, insertQuery, sessionToken, userID, expiresAt, ipAddress, userAgent)
|
||||
reqJSON, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
return nil, fmt.Errorf("failed to marshal passkey login request: %w", err)
|
||||
}
|
||||
|
||||
// Update last login
|
||||
updateQuery := `UPDATE users SET last_login_at = now() WHERE id = $1`
|
||||
_, _ = a.db.ExecContext(ctx, updateQuery, userID)
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
var dataJSON sql.NullString
|
||||
|
||||
// Return login response
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
User: &UserContext{
|
||||
UserID: userID,
|
||||
UserName: username,
|
||||
Email: email,
|
||||
UserLevel: userLevel,
|
||||
SessionID: sessionToken,
|
||||
Roles: parseRoles(roles),
|
||||
},
|
||||
ExpiresIn: int64(24 * time.Hour.Seconds()),
|
||||
}, nil
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_data::text FROM %s($1::jsonb)`, a.sqlNames.PasskeyLogin)
|
||||
err = a.db.QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &dataJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("passkey login query failed: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errorMsg.Valid {
|
||||
return nil, fmt.Errorf("%s", errorMsg.String)
|
||||
}
|
||||
return nil, fmt.Errorf("passkey login failed")
|
||||
}
|
||||
|
||||
var response LoginResponse
|
||||
if err := json.Unmarshal([]byte(dataJSON.String), &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse passkey login response: %w", err)
|
||||
}
|
||||
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// GetPasskeyCredentials returns all passkey credentials for a user
|
||||
|
||||
254
pkg/security/sql_names.go
Normal file
254
pkg/security/sql_names.go
Normal file
@@ -0,0 +1,254 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var validSQLIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
|
||||
|
||||
// SQLNames defines all configurable SQL stored procedure and table names
|
||||
// used by the security package. Override individual fields to remap
|
||||
// to custom database objects. Use DefaultSQLNames() for baseline defaults,
|
||||
// and MergeSQLNames() to apply partial overrides.
|
||||
type SQLNames struct {
|
||||
// Auth procedures (DatabaseAuthenticator)
|
||||
Login string // default: "resolvespec_login"
|
||||
Register string // default: "resolvespec_register"
|
||||
Logout string // default: "resolvespec_logout"
|
||||
Session string // default: "resolvespec_session"
|
||||
SessionUpdate string // default: "resolvespec_session_update"
|
||||
RefreshToken string // default: "resolvespec_refresh_token"
|
||||
|
||||
// JWT procedures (JWTAuthenticator)
|
||||
JWTLogin string // default: "resolvespec_jwt_login"
|
||||
JWTLogout string // default: "resolvespec_jwt_logout"
|
||||
|
||||
// Security policy procedures
|
||||
ColumnSecurity string // default: "resolvespec_column_security"
|
||||
RowSecurity string // default: "resolvespec_row_security"
|
||||
|
||||
// TOTP procedures (DatabaseTwoFactorProvider)
|
||||
TOTPEnable string // default: "resolvespec_totp_enable"
|
||||
TOTPDisable string // default: "resolvespec_totp_disable"
|
||||
TOTPGetStatus string // default: "resolvespec_totp_get_status"
|
||||
TOTPGetSecret string // default: "resolvespec_totp_get_secret"
|
||||
TOTPRegenerateBackup string // default: "resolvespec_totp_regenerate_backup_codes"
|
||||
TOTPValidateBackupCode string // default: "resolvespec_totp_validate_backup_code"
|
||||
|
||||
// Passkey procedures (DatabasePasskeyProvider)
|
||||
PasskeyStoreCredential string // default: "resolvespec_passkey_store_credential"
|
||||
PasskeyGetCredsByUsername string // default: "resolvespec_passkey_get_credentials_by_username"
|
||||
PasskeyGetCredential string // default: "resolvespec_passkey_get_credential"
|
||||
PasskeyUpdateCounter string // default: "resolvespec_passkey_update_counter"
|
||||
PasskeyGetUserCredentials string // default: "resolvespec_passkey_get_user_credentials"
|
||||
PasskeyDeleteCredential string // default: "resolvespec_passkey_delete_credential"
|
||||
PasskeyUpdateName string // default: "resolvespec_passkey_update_name"
|
||||
PasskeyLogin string // default: "resolvespec_passkey_login"
|
||||
|
||||
// OAuth2 procedures (DatabaseAuthenticator OAuth2 methods)
|
||||
OAuthGetOrCreateUser string // default: "resolvespec_oauth_getorcreateuser"
|
||||
OAuthCreateSession string // default: "resolvespec_oauth_createsession"
|
||||
OAuthGetRefreshToken string // default: "resolvespec_oauth_getrefreshtoken"
|
||||
OAuthUpdateRefreshToken string // default: "resolvespec_oauth_updaterefreshtoken"
|
||||
OAuthGetUser string // default: "resolvespec_oauth_getuser"
|
||||
|
||||
// OAuth2 server procedures (OAuthServer persistence)
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
|
||||
// DefaultSQLNames returns an SQLNames with all default resolvespec_* values.
|
||||
func DefaultSQLNames() *SQLNames {
|
||||
return &SQLNames{
|
||||
Login: "resolvespec_login",
|
||||
Register: "resolvespec_register",
|
||||
Logout: "resolvespec_logout",
|
||||
Session: "resolvespec_session",
|
||||
SessionUpdate: "resolvespec_session_update",
|
||||
RefreshToken: "resolvespec_refresh_token",
|
||||
|
||||
JWTLogin: "resolvespec_jwt_login",
|
||||
JWTLogout: "resolvespec_jwt_logout",
|
||||
|
||||
ColumnSecurity: "resolvespec_column_security",
|
||||
RowSecurity: "resolvespec_row_security",
|
||||
|
||||
TOTPEnable: "resolvespec_totp_enable",
|
||||
TOTPDisable: "resolvespec_totp_disable",
|
||||
TOTPGetStatus: "resolvespec_totp_get_status",
|
||||
TOTPGetSecret: "resolvespec_totp_get_secret",
|
||||
TOTPRegenerateBackup: "resolvespec_totp_regenerate_backup_codes",
|
||||
TOTPValidateBackupCode: "resolvespec_totp_validate_backup_code",
|
||||
|
||||
PasskeyStoreCredential: "resolvespec_passkey_store_credential",
|
||||
PasskeyGetCredsByUsername: "resolvespec_passkey_get_credentials_by_username",
|
||||
PasskeyGetCredential: "resolvespec_passkey_get_credential",
|
||||
PasskeyUpdateCounter: "resolvespec_passkey_update_counter",
|
||||
PasskeyGetUserCredentials: "resolvespec_passkey_get_user_credentials",
|
||||
PasskeyDeleteCredential: "resolvespec_passkey_delete_credential",
|
||||
PasskeyUpdateName: "resolvespec_passkey_update_name",
|
||||
PasskeyLogin: "resolvespec_passkey_login",
|
||||
|
||||
OAuthGetOrCreateUser: "resolvespec_oauth_getorcreateuser",
|
||||
OAuthCreateSession: "resolvespec_oauth_createsession",
|
||||
OAuthGetRefreshToken: "resolvespec_oauth_getrefreshtoken",
|
||||
OAuthUpdateRefreshToken: "resolvespec_oauth_updaterefreshtoken",
|
||||
OAuthGetUser: "resolvespec_oauth_getuser",
|
||||
|
||||
OAuthRegisterClient: "resolvespec_oauth_register_client",
|
||||
OAuthGetClient: "resolvespec_oauth_get_client",
|
||||
OAuthSaveCode: "resolvespec_oauth_save_code",
|
||||
OAuthExchangeCode: "resolvespec_oauth_exchange_code",
|
||||
OAuthIntrospect: "resolvespec_oauth_introspect",
|
||||
OAuthRevoke: "resolvespec_oauth_revoke",
|
||||
}
|
||||
}
|
||||
|
||||
// MergeSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||
// If override is nil, a copy of base is returned.
|
||||
func MergeSQLNames(base, override *SQLNames) *SQLNames {
|
||||
if override == nil {
|
||||
copied := *base
|
||||
return &copied
|
||||
}
|
||||
merged := *base
|
||||
if override.Login != "" {
|
||||
merged.Login = override.Login
|
||||
}
|
||||
if override.Register != "" {
|
||||
merged.Register = override.Register
|
||||
}
|
||||
if override.Logout != "" {
|
||||
merged.Logout = override.Logout
|
||||
}
|
||||
if override.Session != "" {
|
||||
merged.Session = override.Session
|
||||
}
|
||||
if override.SessionUpdate != "" {
|
||||
merged.SessionUpdate = override.SessionUpdate
|
||||
}
|
||||
if override.RefreshToken != "" {
|
||||
merged.RefreshToken = override.RefreshToken
|
||||
}
|
||||
if override.JWTLogin != "" {
|
||||
merged.JWTLogin = override.JWTLogin
|
||||
}
|
||||
if override.JWTLogout != "" {
|
||||
merged.JWTLogout = override.JWTLogout
|
||||
}
|
||||
if override.ColumnSecurity != "" {
|
||||
merged.ColumnSecurity = override.ColumnSecurity
|
||||
}
|
||||
if override.RowSecurity != "" {
|
||||
merged.RowSecurity = override.RowSecurity
|
||||
}
|
||||
if override.TOTPEnable != "" {
|
||||
merged.TOTPEnable = override.TOTPEnable
|
||||
}
|
||||
if override.TOTPDisable != "" {
|
||||
merged.TOTPDisable = override.TOTPDisable
|
||||
}
|
||||
if override.TOTPGetStatus != "" {
|
||||
merged.TOTPGetStatus = override.TOTPGetStatus
|
||||
}
|
||||
if override.TOTPGetSecret != "" {
|
||||
merged.TOTPGetSecret = override.TOTPGetSecret
|
||||
}
|
||||
if override.TOTPRegenerateBackup != "" {
|
||||
merged.TOTPRegenerateBackup = override.TOTPRegenerateBackup
|
||||
}
|
||||
if override.TOTPValidateBackupCode != "" {
|
||||
merged.TOTPValidateBackupCode = override.TOTPValidateBackupCode
|
||||
}
|
||||
if override.PasskeyStoreCredential != "" {
|
||||
merged.PasskeyStoreCredential = override.PasskeyStoreCredential
|
||||
}
|
||||
if override.PasskeyGetCredsByUsername != "" {
|
||||
merged.PasskeyGetCredsByUsername = override.PasskeyGetCredsByUsername
|
||||
}
|
||||
if override.PasskeyGetCredential != "" {
|
||||
merged.PasskeyGetCredential = override.PasskeyGetCredential
|
||||
}
|
||||
if override.PasskeyUpdateCounter != "" {
|
||||
merged.PasskeyUpdateCounter = override.PasskeyUpdateCounter
|
||||
}
|
||||
if override.PasskeyGetUserCredentials != "" {
|
||||
merged.PasskeyGetUserCredentials = override.PasskeyGetUserCredentials
|
||||
}
|
||||
if override.PasskeyDeleteCredential != "" {
|
||||
merged.PasskeyDeleteCredential = override.PasskeyDeleteCredential
|
||||
}
|
||||
if override.PasskeyUpdateName != "" {
|
||||
merged.PasskeyUpdateName = override.PasskeyUpdateName
|
||||
}
|
||||
if override.PasskeyLogin != "" {
|
||||
merged.PasskeyLogin = override.PasskeyLogin
|
||||
}
|
||||
if override.OAuthGetOrCreateUser != "" {
|
||||
merged.OAuthGetOrCreateUser = override.OAuthGetOrCreateUser
|
||||
}
|
||||
if override.OAuthCreateSession != "" {
|
||||
merged.OAuthCreateSession = override.OAuthCreateSession
|
||||
}
|
||||
if override.OAuthGetRefreshToken != "" {
|
||||
merged.OAuthGetRefreshToken = override.OAuthGetRefreshToken
|
||||
}
|
||||
if override.OAuthUpdateRefreshToken != "" {
|
||||
merged.OAuthUpdateRefreshToken = override.OAuthUpdateRefreshToken
|
||||
}
|
||||
if override.OAuthGetUser != "" {
|
||||
merged.OAuthGetUser = override.OAuthGetUser
|
||||
}
|
||||
if override.OAuthRegisterClient != "" {
|
||||
merged.OAuthRegisterClient = override.OAuthRegisterClient
|
||||
}
|
||||
if override.OAuthGetClient != "" {
|
||||
merged.OAuthGetClient = override.OAuthGetClient
|
||||
}
|
||||
if override.OAuthSaveCode != "" {
|
||||
merged.OAuthSaveCode = override.OAuthSaveCode
|
||||
}
|
||||
if override.OAuthExchangeCode != "" {
|
||||
merged.OAuthExchangeCode = override.OAuthExchangeCode
|
||||
}
|
||||
if override.OAuthIntrospect != "" {
|
||||
merged.OAuthIntrospect = override.OAuthIntrospect
|
||||
}
|
||||
if override.OAuthRevoke != "" {
|
||||
merged.OAuthRevoke = override.OAuthRevoke
|
||||
}
|
||||
return &merged
|
||||
}
|
||||
|
||||
// ValidateSQLNames checks that all non-empty fields in names are valid SQL identifiers.
|
||||
// Returns an error if any field contains invalid characters.
|
||||
func ValidateSQLNames(names *SQLNames) error {
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
val := field.String()
|
||||
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||
return fmt.Errorf("SQLNames.%s contains invalid characters: %q", typ.Field(i).Name, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveSQLNames merges an optional override with defaults.
|
||||
// Used by constructors that accept variadic *SQLNames.
|
||||
func resolveSQLNames(override ...*SQLNames) *SQLNames {
|
||||
if len(override) > 0 && override[0] != nil {
|
||||
return MergeSQLNames(DefaultSQLNames(), override[0])
|
||||
}
|
||||
return DefaultSQLNames()
|
||||
}
|
||||
145
pkg/security/sql_names_test.go
Normal file
145
pkg/security/sql_names_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultSQLNames_AllFieldsNonEmpty(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
v := reflect.ValueOf(names).Elem()
|
||||
typ := v.Type()
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if field.String() == "" {
|
||||
t.Errorf("DefaultSQLNames().%s is empty", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_PartialOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{
|
||||
Login: "custom_login",
|
||||
TOTPEnable: "custom_totp_enable",
|
||||
PasskeyLogin: "custom_passkey_login",
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
|
||||
if merged.Login != "custom_login" {
|
||||
t.Errorf("MergeSQLNames().Login = %q, want %q", merged.Login, "custom_login")
|
||||
}
|
||||
if merged.TOTPEnable != "custom_totp_enable" {
|
||||
t.Errorf("MergeSQLNames().TOTPEnable = %q, want %q", merged.TOTPEnable, "custom_totp_enable")
|
||||
}
|
||||
if merged.PasskeyLogin != "custom_passkey_login" {
|
||||
t.Errorf("MergeSQLNames().PasskeyLogin = %q, want %q", merged.PasskeyLogin, "custom_passkey_login")
|
||||
}
|
||||
// Non-overridden fields should retain defaults
|
||||
if merged.Logout != "resolvespec_logout" {
|
||||
t.Errorf("MergeSQLNames().Logout = %q, want %q", merged.Logout, "resolvespec_logout")
|
||||
}
|
||||
if merged.Session != "resolvespec_session" {
|
||||
t.Errorf("MergeSQLNames().Session = %q, want %q", merged.Session, "resolvespec_session")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_NilOverride(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
merged := MergeSQLNames(base, nil)
|
||||
|
||||
// Should be a copy, not the same pointer
|
||||
if merged == base {
|
||||
t.Error("MergeSQLNames with nil override should return a copy, not the same pointer")
|
||||
}
|
||||
|
||||
// All values should match
|
||||
v1 := reflect.ValueOf(base).Elem()
|
||||
v2 := reflect.ValueOf(merged).Elem()
|
||||
typ := v1.Type()
|
||||
|
||||
for i := 0; i < v1.NumField(); i++ {
|
||||
f1 := v1.Field(i)
|
||||
f2 := v2.Field(i)
|
||||
if f1.Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if f1.String() != f2.String() {
|
||||
t.Errorf("MergeSQLNames(base, nil).%s = %q, want %q", typ.Field(i).Name, f2.String(), f1.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_DoesNotMutateBase(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
originalLogin := base.Login
|
||||
|
||||
override := &SQLNames{Login: "custom_login"}
|
||||
_ = MergeSQLNames(base, override)
|
||||
|
||||
if base.Login != originalLogin {
|
||||
t.Errorf("MergeSQLNames mutated base: Login = %q, want %q", base.Login, originalLogin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeSQLNames_AllFieldsMerged(t *testing.T) {
|
||||
base := DefaultSQLNames()
|
||||
override := &SQLNames{}
|
||||
v := reflect.ValueOf(override).Elem()
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if v.Field(i).Kind() == reflect.String {
|
||||
v.Field(i).SetString("custom_sentinel")
|
||||
}
|
||||
}
|
||||
|
||||
merged := MergeSQLNames(base, override)
|
||||
mv := reflect.ValueOf(merged).Elem()
|
||||
typ := mv.Type()
|
||||
for i := 0; i < mv.NumField(); i++ {
|
||||
if mv.Field(i).Kind() != reflect.String {
|
||||
continue
|
||||
}
|
||||
if mv.Field(i).String() != "custom_sentinel" {
|
||||
t.Errorf("MergeSQLNames did not merge field %s", typ.Field(i).Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Valid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
if err := ValidateSQLNames(names); err != nil {
|
||||
t.Errorf("ValidateSQLNames(defaults) error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSQLNames_Invalid(t *testing.T) {
|
||||
names := DefaultSQLNames()
|
||||
names.Login = "resolvespec_login; DROP TABLE users; --"
|
||||
|
||||
err := ValidateSQLNames(names)
|
||||
if err == nil {
|
||||
t.Error("ValidateSQLNames should reject names with invalid characters")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_NoOverride(t *testing.T) {
|
||||
names := resolveSQLNames()
|
||||
if names.Login != "resolvespec_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want default", names.Login)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveSQLNames_WithOverride(t *testing.T) {
|
||||
names := resolveSQLNames(&SQLNames{Login: "custom_login"})
|
||||
if names.Login != "custom_login" {
|
||||
t.Errorf("resolveSQLNames().Login = %q, want %q", names.Login, "custom_login")
|
||||
}
|
||||
if names.Logout != "resolvespec_logout" {
|
||||
t.Errorf("resolveSQLNames().Logout = %q, want default", names.Logout)
|
||||
}
|
||||
}
|
||||
@@ -9,23 +9,23 @@ import (
|
||||
)
|
||||
|
||||
// DatabaseTwoFactorProvider implements TwoFactorAuthProvider using PostgreSQL stored procedures
|
||||
// Requires stored procedures: resolvespec_totp_enable, resolvespec_totp_disable,
|
||||
// resolvespec_totp_get_status, resolvespec_totp_get_secret,
|
||||
// resolvespec_totp_regenerate_backup_codes, resolvespec_totp_validate_backup_code
|
||||
// Procedure names are configurable via SQLNames (see DefaultSQLNames for defaults)
|
||||
// See totp_database_schema.sql for procedure definitions
|
||||
type DatabaseTwoFactorProvider struct {
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
db *sql.DB
|
||||
totpGen *TOTPGenerator
|
||||
sqlNames *SQLNames
|
||||
}
|
||||
|
||||
// NewDatabaseTwoFactorProvider creates a new database-backed 2FA provider
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig) *DatabaseTwoFactorProvider {
|
||||
func NewDatabaseTwoFactorProvider(db *sql.DB, config *TwoFactorConfig, names ...*SQLNames) *DatabaseTwoFactorProvider {
|
||||
if config == nil {
|
||||
config = DefaultTwoFactorConfig()
|
||||
}
|
||||
return &DatabaseTwoFactorProvider{
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
db: db,
|
||||
totpGen: NewTOTPGenerator(config),
|
||||
sqlNames: resolveSQLNames(names...),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +76,7 @@ func (p *DatabaseTwoFactorProvider) Enable2FA(userID int, secret string, backupC
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_enable($1, $2, $3::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2, $3::jsonb)`, p.sqlNames.TOTPEnable)
|
||||
err = p.db.QueryRow(query, userID, secret, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enable 2FA query failed: %w", err)
|
||||
@@ -97,7 +97,7 @@ func (p *DatabaseTwoFactorProvider) Disable2FA(userID int) error {
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_disable($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1)`, p.sqlNames.TOTPDisable)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("disable 2FA query failed: %w", err)
|
||||
@@ -119,7 +119,7 @@ func (p *DatabaseTwoFactorProvider) Get2FAStatus(userID int) (bool, error) {
|
||||
var errorMsg sql.NullString
|
||||
var enabled bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_enabled FROM resolvespec_totp_get_status($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_enabled FROM %s($1)`, p.sqlNames.TOTPGetStatus)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &enabled)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get 2FA status query failed: %w", err)
|
||||
@@ -141,7 +141,7 @@ func (p *DatabaseTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var errorMsg sql.NullString
|
||||
var secret sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error, p_secret FROM resolvespec_totp_get_secret($1)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_secret FROM %s($1)`, p.sqlNames.TOTPGetSecret)
|
||||
err := p.db.QueryRow(query, userID).Scan(&success, &errorMsg, &secret)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get 2FA secret query failed: %w", err)
|
||||
@@ -185,7 +185,7 @@ func (p *DatabaseTwoFactorProvider) GenerateBackupCodes(userID int, count int) (
|
||||
var success bool
|
||||
var errorMsg sql.NullString
|
||||
|
||||
query := `SELECT p_success, p_error FROM resolvespec_totp_regenerate_backup_codes($1, $2::jsonb)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error FROM %s($1, $2::jsonb)`, p.sqlNames.TOTPRegenerateBackup)
|
||||
err = p.db.QueryRow(query, userID, string(codesJSON)).Scan(&success, &errorMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("regenerate backup codes query failed: %w", err)
|
||||
@@ -212,7 +212,7 @@ func (p *DatabaseTwoFactorProvider) ValidateBackupCode(userID int, code string)
|
||||
var errorMsg sql.NullString
|
||||
var valid bool
|
||||
|
||||
query := `SELECT p_success, p_error, p_valid FROM resolvespec_totp_validate_backup_code($1, $2)`
|
||||
query := fmt.Sprintf(`SELECT p_success, p_error, p_valid FROM %s($1, $2)`, p.sqlNames.TOTPValidateBackupCode)
|
||||
err := p.db.QueryRow(query, userID, codeHash).Scan(&success, &errorMsg, &valid)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("validate backup code query failed: %w", err)
|
||||
|
||||
@@ -3,6 +3,7 @@ package server_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -29,18 +30,18 @@ func ExampleManager_basic() {
|
||||
GZIP: true, // Enable GZIP compression
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Server is now running...
|
||||
// When done, stop gracefully
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +62,7 @@ func ExampleManager_https() {
|
||||
SSLKey: "/path/to/key.pem",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Option 2: Self-signed certificate (for development)
|
||||
@@ -73,27 +74,27 @@ func ExampleManager_https() {
|
||||
SelfSignedSSL: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Option 3: Let's Encrypt / AutoTLS (for production)
|
||||
_, err = mgr.Add(server.Config{
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
Name: "https-server-letsencrypt",
|
||||
Host: "0.0.0.0",
|
||||
Port: 443,
|
||||
Handler: handler,
|
||||
AutoTLS: true,
|
||||
AutoTLSDomains: []string{"example.com", "www.example.com"},
|
||||
AutoTLSEmail: "admin@example.com",
|
||||
AutoTLSCacheDir: "./certs-cache",
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
@@ -136,7 +137,7 @@ func ExampleManager_gracefulShutdown() {
|
||||
IdleTimeout: 120 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start servers and block until shutdown signal (SIGINT/SIGTERM)
|
||||
@@ -164,7 +165,7 @@ func ExampleManager_healthChecks() {
|
||||
Handler: mux,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Add health and readiness endpoints
|
||||
@@ -173,7 +174,7 @@ func ExampleManager_healthChecks() {
|
||||
|
||||
// Start the server
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Health check returns:
|
||||
@@ -204,7 +205,7 @@ func ExampleManager_multipleServers() {
|
||||
GZIP: true,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Admin API server (different port)
|
||||
@@ -218,7 +219,7 @@ func ExampleManager_multipleServers() {
|
||||
Handler: adminHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Metrics server (internal only)
|
||||
@@ -232,18 +233,18 @@ func ExampleManager_multipleServers() {
|
||||
Handler: metricsHandler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Start all servers at once
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Get specific server instance
|
||||
publicInstance, err := mgr.Get("public-api")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Public API running on: %s\n", publicInstance.Addr())
|
||||
|
||||
@@ -253,7 +254,7 @@ func ExampleManager_multipleServers() {
|
||||
|
||||
// Stop all servers gracefully (in parallel)
|
||||
if err := mgr.StopAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,11 +274,11 @@ func ExampleManager_monitoring() {
|
||||
Handler: handler,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := mgr.StartAll(); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Check server status
|
||||
|
||||
@@ -98,6 +98,7 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
// Apply prefix stripping by prepending the prefix to the requested path
|
||||
actualPath := name
|
||||
alternatePath := ""
|
||||
if p.stripPrefix != "" {
|
||||
// Clean the paths to handle leading/trailing slashes
|
||||
prefix := strings.Trim(p.stripPrefix, "/")
|
||||
@@ -105,12 +106,25 @@ func (p *EmbedFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
if prefix != "" {
|
||||
actualPath = path.Join(prefix, cleanName)
|
||||
alternatePath = cleanName
|
||||
} else {
|
||||
actualPath = cleanName
|
||||
}
|
||||
}
|
||||
// First try the actual path with prefix
|
||||
if file, err := p.fs.Open(actualPath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
|
||||
return p.fs.Open(actualPath)
|
||||
// If alternate path is different, try it as well
|
||||
if alternatePath != "" && alternatePath != actualPath {
|
||||
if file, err := p.fs.Open(alternatePath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If both attempts fail, return the error from the first attempt
|
||||
return nil, fmt.Errorf("file not found: %s", name)
|
||||
}
|
||||
|
||||
// Close releases any resources held by the provider.
|
||||
|
||||
@@ -53,6 +53,7 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
// Apply prefix stripping by prepending the prefix to the requested path
|
||||
actualPath := name
|
||||
alternatePath := ""
|
||||
if p.stripPrefix != "" {
|
||||
// Clean the paths to handle leading/trailing slashes
|
||||
prefix := strings.Trim(p.stripPrefix, "/")
|
||||
@@ -60,12 +61,26 @@ func (p *LocalFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
if prefix != "" {
|
||||
actualPath = path.Join(prefix, cleanName)
|
||||
alternatePath = cleanName
|
||||
} else {
|
||||
actualPath = cleanName
|
||||
}
|
||||
}
|
||||
|
||||
return p.fs.Open(actualPath)
|
||||
// First try the actual path with prefix
|
||||
if file, err := p.fs.Open(actualPath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// If alternate path is different, try it as well
|
||||
if alternatePath != "" && alternatePath != actualPath {
|
||||
if file, err := p.fs.Open(alternatePath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If both attempts fail, return the error from the first attempt
|
||||
return nil, fmt.Errorf("file not found: %s", name)
|
||||
}
|
||||
|
||||
// Close releases any resources held by the provider.
|
||||
|
||||
@@ -56,6 +56,7 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
// Apply prefix stripping by prepending the prefix to the requested path
|
||||
actualPath := name
|
||||
alternatePath := ""
|
||||
if p.stripPrefix != "" {
|
||||
// Clean the paths to handle leading/trailing slashes
|
||||
prefix := strings.Trim(p.stripPrefix, "/")
|
||||
@@ -63,12 +64,26 @@ func (p *ZipFSProvider) Open(name string) (fs.File, error) {
|
||||
|
||||
if prefix != "" {
|
||||
actualPath = path.Join(prefix, cleanName)
|
||||
alternatePath = cleanName
|
||||
} else {
|
||||
actualPath = cleanName
|
||||
}
|
||||
}
|
||||
|
||||
return p.zipFS.Open(actualPath)
|
||||
// First try the actual path with prefix
|
||||
if file, err := p.zipFS.Open(actualPath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// If alternate path is different, try it as well
|
||||
if alternatePath != "" && alternatePath != actualPath {
|
||||
if file, err := p.zipFS.Open(alternatePath); err == nil {
|
||||
return file, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If both attempts fail, return the error from the first attempt
|
||||
return nil, fmt.Errorf("file not found: %s", name)
|
||||
}
|
||||
|
||||
// Close releases resources held by the zip reader.
|
||||
|
||||
@@ -330,6 +330,7 @@ Hooks allow you to intercept and modify operations at various points in the life
|
||||
|
||||
### Available Hook Types
|
||||
|
||||
- **BeforeHandle** — fires after model resolution, before operation dispatch (auth checks)
|
||||
- **BeforeRead** / **AfterRead**
|
||||
- **BeforeCreate** / **AfterCreate**
|
||||
- **BeforeUpdate** / **AfterUpdate**
|
||||
@@ -337,6 +338,8 @@ Hooks allow you to intercept and modify operations at various points in the life
|
||||
- **BeforeSubscribe** / **AfterSubscribe**
|
||||
- **BeforeConnect** / **AfterConnect**
|
||||
|
||||
`HookContext` includes `Operation string` (`"read"`, `"create"`, `"update"`, `"delete"`) and `Abort bool`, `AbortMessage string`, `AbortCode int` for abort signaling.
|
||||
|
||||
### Hook Example
|
||||
|
||||
```go
|
||||
@@ -599,7 +602,19 @@ asyncio.run(main())
|
||||
|
||||
## Authentication
|
||||
|
||||
Implement authentication using hooks:
|
||||
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
websocketspec.RegisterSecurityHooks(handler, securityList)
|
||||
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||
```
|
||||
|
||||
Or implement custom authentication using hooks directly:
|
||||
|
||||
```go
|
||||
handler := websocketspec.NewHandlerWithGORM(db)
|
||||
|
||||
@@ -177,6 +177,16 @@ func (h *Handler) handleRequest(conn *Connection, msg *Message) {
|
||||
Metadata: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
hookCtx.Operation = string(msg.Operation)
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
if hookCtx.Abort {
|
||||
errResp := NewErrorResponse(msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||
_ = conn.SendJSON(errResp)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
@@ -618,7 +628,10 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||
if hookCtx.Options != nil {
|
||||
for _, filter := range hookCtx.Options.Filters {
|
||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||
cond, args := h.buildFilterCondition(filter)
|
||||
if cond != "" {
|
||||
countQuery = countQuery.Where(cond, args...)
|
||||
}
|
||||
}
|
||||
}
|
||||
count, _ := countQuery.Count(hookCtx.Context)
|
||||
@@ -790,14 +803,12 @@ func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.Fi
|
||||
|
||||
// buildFilterCondition builds a filter condition and returns it with args
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
if strings.EqualFold(filter.Operator, "in") {
|
||||
cond, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return cond, args
|
||||
}
|
||||
operatorSQL := h.getOperatorSQL(filter.Operator)
|
||||
condition = fmt.Sprintf("%s %s ?", filter.Column, operatorSQL)
|
||||
args = []interface{}{filter.Value}
|
||||
|
||||
return condition, args
|
||||
return fmt.Sprintf("%s %s ?", filter.Column, operatorSQL), []interface{}{filter.Value}
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
|
||||
@@ -2,6 +2,7 @@ package websocketspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
@@ -10,6 +11,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// BeforeRead is called before a read operation
|
||||
BeforeRead HookType = "before_read"
|
||||
// AfterRead is called after a read operation
|
||||
@@ -83,6 +88,9 @@ type HookContext struct {
|
||||
// Options contains the parsed request options
|
||||
Options *common.RequestOptions
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// ID is the record ID for single-record operations
|
||||
ID string
|
||||
|
||||
@@ -98,6 +106,11 @@ type HookContext struct {
|
||||
// Error is any error that occurred (for after hooks)
|
||||
Error error
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Metadata is additional context data
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
@@ -171,6 +184,11 @@ func (hr *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
if err := hook(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
108
pkg/websocketspec/security_hooks.go
Normal file
108
pkg/websocketspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package websocketspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for websocketspec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts websocketspec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
// GetQuery retrieves a stored query from hook metadata (websocketspec has no Query field)
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
if s.ctx.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return s.ctx.Metadata["query"]
|
||||
}
|
||||
|
||||
// SetQuery stores the query in hook metadata
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if s.ctx.Metadata == nil {
|
||||
s.ctx.Metadata = make(map[string]interface{})
|
||||
}
|
||||
s.ctx.Metadata["query"] = query
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
8
resolvespec-js/.changeset/README.md
Normal file
8
resolvespec-js/.changeset/README.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# Changesets
|
||||
|
||||
Hello and welcome! This folder has been automatically generated by `@changesets/cli`, a build tool that works
|
||||
with multi-package repos, or single-package repos to help you version and publish your code. You can
|
||||
find the full documentation for it [in our repository](https://github.com/changesets/changesets)
|
||||
|
||||
We have a quick list of common questions to get you started engaging with this project in
|
||||
[our documentation](https://github.com/changesets/changesets/blob/main/docs/common-questions.md)
|
||||
11
resolvespec-js/.changeset/config.json
Normal file
11
resolvespec-js/.changeset/config.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"$schema": "https://unpkg.com/@changesets/config@3.1.2/schema.json",
|
||||
"changelog": "@changesets/cli/changelog",
|
||||
"commit": false,
|
||||
"fixed": [],
|
||||
"linked": [],
|
||||
"access": "restricted",
|
||||
"baseBranch": "main",
|
||||
"updateInternalDependencies": "patch",
|
||||
"ignore": []
|
||||
}
|
||||
7
resolvespec-js/CHANGELOG.md
Normal file
7
resolvespec-js/CHANGELOG.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# @warkypublic/resolvespec-js
|
||||
|
||||
## 1.0.1
|
||||
|
||||
### Patch Changes
|
||||
|
||||
- Fixed headerpsec
|
||||
132
resolvespec-js/PLAN.md
Normal file
132
resolvespec-js/PLAN.md
Normal file
@@ -0,0 +1,132 @@
|
||||
# ResolveSpec JS - Implementation Plan
|
||||
|
||||
TypeScript client library for ResolveSpec, RestHeaderSpec, WebSocket and MQTT APIs.
|
||||
|
||||
---
|
||||
|
||||
## Status
|
||||
|
||||
| Phase | Description | Status |
|
||||
|-------|-------------|--------|
|
||||
| 0 | Restructure into folders | Done |
|
||||
| 1 | Fix types (align with Go) | Done |
|
||||
| 2 | Fix REST client | Done |
|
||||
| 3 | Build config | Done |
|
||||
| 4 | Tests | Done |
|
||||
| 5 | HeaderSpec client | Done |
|
||||
| 6 | MQTT client | Planned |
|
||||
| 6.5 | Unified class pattern + singleton factories | Done |
|
||||
| 7 | Response cache (TTL) | Planned |
|
||||
| 8 | TanStack Query integration | Planned |
|
||||
| 9 | React Hooks | Planned |
|
||||
|
||||
**Build:** `dist/index.js` (ES) + `dist/index.cjs` (CJS) + `.d.ts` declarations
|
||||
**Tests:** 65 passing (common: 10, resolvespec: 13, websocketspec: 15, headerspec: 27)
|
||||
|
||||
---
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── common/
|
||||
│ ├── types.ts # Core types aligned with Go pkg/common/types.go
|
||||
│ └── index.ts
|
||||
├── resolvespec/
|
||||
│ ├── client.ts # ResolveSpecClient class + createResolveSpecClient singleton
|
||||
│ └── index.ts
|
||||
├── headerspec/
|
||||
│ ├── client.ts # HeaderSpecClient class + createHeaderSpecClient singleton + buildHeaders utility
|
||||
│ └── index.ts
|
||||
├── websocketspec/
|
||||
│ ├── types.ts # WS-specific types (WSMessage, WSOptions, etc.)
|
||||
│ ├── client.ts # WebSocketClient class + createWebSocketClient singleton
|
||||
│ └── index.ts
|
||||
├── mqttspec/ # Future
|
||||
│ ├── types.ts
|
||||
│ ├── client.ts
|
||||
│ └── index.ts
|
||||
├── __tests__/
|
||||
│ ├── common.test.ts
|
||||
│ ├── resolvespec.test.ts
|
||||
│ ├── headerspec.test.ts
|
||||
│ └── websocketspec.test.ts
|
||||
└── index.ts # Root barrel export
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Type Alignment with Go
|
||||
|
||||
Types in `src/common/types.ts` match `pkg/common/types.go`:
|
||||
|
||||
- **Operator**: `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, `contains`, `startswith`, `endswith`, `between`, `between_inclusive`, `is_null`, `is_not_null`
|
||||
- **FilterOption**: `column`, `operator`, `value`, `logic_operator` (AND/OR)
|
||||
- **Options**: `columns`, `omit_columns`, `filters`, `sort`, `limit`, `offset`, `preload`, `customOperators`, `computedColumns`, `parameters`, `cursor_forward`, `cursor_backward`, `fetch_row_number`
|
||||
- **PreloadOption**: `relation`, `table_name`, `columns`, `omit_columns`, `sort`, `filters`, `where`, `limit`, `offset`, `updatable`, `recursive`, `computed_ql`, `primary_key`, `related_key`, `foreign_key`, `recursive_child_key`, `sql_joins`, `join_aliases`
|
||||
- **Parameter**: `name`, `value`, `sequence?`
|
||||
- **Metadata**: `total`, `count`, `filtered`, `limit`, `offset`, `row_number?`
|
||||
- **APIError**: `code`, `message`, `details?`, `detail?`
|
||||
|
||||
---
|
||||
|
||||
## HeaderSpec Header Mapping
|
||||
|
||||
Maps Options to HTTP headers per Go `restheadspec/headers.go`:
|
||||
|
||||
| Header | Options field | Format |
|
||||
|--------|--------------|--------|
|
||||
| `X-Select-Fields` | `columns` | comma-separated |
|
||||
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
|
||||
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
|
||||
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
|
||||
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
|
||||
| `X-Sort` | `sort` | `+col` (asc), `-col` (desc) |
|
||||
| `X-Limit` | `limit` | number |
|
||||
| `X-Offset` | `offset` | number |
|
||||
| `X-Cursor-Forward` | `cursor_forward` | string |
|
||||
| `X-Cursor-Backward` | `cursor_backward` | string |
|
||||
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
|
||||
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
|
||||
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
|
||||
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |
|
||||
|
||||
Complex values use `ZIP_` + base64 encoding.
|
||||
HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete.
|
||||
|
||||
---
|
||||
|
||||
## Build & Test
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
pnpm run build # vite library mode → dist/
|
||||
pnpm run test # vitest
|
||||
pnpm run lint # eslint
|
||||
```
|
||||
|
||||
**Config files:** `tsconfig.json` (ES2020, strict, bundler), `vite.config.ts` (lib mode, dts via vite-plugin-dts)
|
||||
**Externals:** `uuid`, `semver`
|
||||
|
||||
---
|
||||
|
||||
## Remaining Work
|
||||
|
||||
- **Phase 6 — MQTT Client**: Topic-based CRUD over MQTT (optional/future)
|
||||
- **Phase 7 — Cache**: In-memory response cache with TTL, key = URL + options hash, auto-invalidation on CUD, `skipCache` flag
|
||||
- **Phase 8 — TanStack Query Integration**: Query/mutation hooks wrapping each client, query key factories, automatic cache invalidation
|
||||
- **Phase 9 — React Hooks**: `useResolveSpec`, `useHeaderSpec`, `useWebSocket` hooks with provider context, loading/error states
|
||||
- ESLint config may need updating for new folder structure
|
||||
|
||||
---
|
||||
|
||||
## Reference Files
|
||||
|
||||
| Purpose | Path |
|
||||
|---------|------|
|
||||
| Go types (source of truth) | `pkg/common/types.go` |
|
||||
| Go REST handler | `pkg/resolvespec/handler.go` |
|
||||
| Go HeaderSpec handler | `pkg/restheadspec/handler.go` |
|
||||
| Go HeaderSpec header parsing | `pkg/restheadspec/headers.go` |
|
||||
| Go test models | `pkg/testmodels/business.go` |
|
||||
| Go tests | `tests/crud_test.go` |
|
||||
213
resolvespec-js/README.md
Normal file
213
resolvespec-js/README.md
Normal file
@@ -0,0 +1,213 @@
|
||||
# ResolveSpec JS
|
||||
|
||||
TypeScript client library for ResolveSpec APIs. Supports body-based REST, header-based REST, and WebSocket protocols.
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pnpm add @warkypublic/resolvespec-js
|
||||
```
|
||||
|
||||
## Clients
|
||||
|
||||
| Client | Protocol | Singleton Factory |
|
||||
| --- | --- | --- |
|
||||
| `ResolveSpecClient` | REST (body-based) | `getResolveSpecClient(config)` |
|
||||
| `HeaderSpecClient` | REST (header-based) | `getHeaderSpecClient(config)` |
|
||||
| `WebSocketClient` | WebSocket | `getWebSocketClient(config)` |
|
||||
|
||||
All clients use the class pattern. Singleton factories return cached instances keyed by URL.
|
||||
|
||||
## REST Client (Body-Based)
|
||||
|
||||
Options sent in JSON request body. Maps to Go `pkg/resolvespec`.
|
||||
|
||||
```typescript
|
||||
import { ResolveSpecClient, getResolveSpecClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
// Class instantiation
|
||||
const client = new ResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// Or singleton factory (returns cached instance per baseUrl)
|
||||
const client = getResolveSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// Read with filters, sort, pagination
|
||||
const result = await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name', 'email'],
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
preload: [{ relation: 'Posts', columns: ['id', 'title'] }],
|
||||
});
|
||||
|
||||
// Read by ID
|
||||
const user = await client.read('public', 'users', 42);
|
||||
|
||||
// Create
|
||||
const created = await client.create('public', 'users', { name: 'New User' });
|
||||
|
||||
// Update
|
||||
await client.update('public', 'users', { name: 'Updated' }, 42);
|
||||
|
||||
// Delete
|
||||
await client.delete('public', 'users', 42);
|
||||
|
||||
// Metadata
|
||||
const meta = await client.getMetadata('public', 'users');
|
||||
```
|
||||
|
||||
## HeaderSpec Client (Header-Based)
|
||||
|
||||
Options sent via HTTP headers. Maps to Go `pkg/restheadspec`.
|
||||
|
||||
```typescript
|
||||
import { HeaderSpecClient, getHeaderSpecClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const client = new HeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
// Or: const client = getHeaderSpecClient({ baseUrl: 'http://localhost:3000', token: 'your-token' });
|
||||
|
||||
// GET with options as headers
|
||||
const result = await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name'],
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' },
|
||||
{ column: 'age', operator: 'gte', value: 18, logic_operator: 'AND' },
|
||||
],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 50,
|
||||
preload: [{ relation: 'Department', columns: ['id', 'name'] }],
|
||||
});
|
||||
|
||||
// POST create
|
||||
await client.create('public', 'users', { name: 'New User' });
|
||||
|
||||
// PUT update
|
||||
await client.update('public', 'users', '42', { name: 'Updated' });
|
||||
|
||||
// DELETE
|
||||
await client.delete('public', 'users', '42');
|
||||
```
|
||||
|
||||
### Header Mapping
|
||||
|
||||
| Header | Options Field | Format |
|
||||
| --- | --- | --- |
|
||||
| `X-Select-Fields` | `columns` | comma-separated |
|
||||
| `X-Not-Select-Fields` | `omit_columns` | comma-separated |
|
||||
| `X-FieldFilter-{col}` | `filters` (eq, AND) | value |
|
||||
| `X-SearchOp-{op}-{col}` | `filters` (AND) | value |
|
||||
| `X-SearchOr-{op}-{col}` | `filters` (OR) | value |
|
||||
| `X-Sort` | `sort` | `+col` asc, `-col` desc |
|
||||
| `X-Limit` / `X-Offset` | `limit` / `offset` | number |
|
||||
| `X-Cursor-Forward` | `cursor_forward` | string |
|
||||
| `X-Cursor-Backward` | `cursor_backward` | string |
|
||||
| `X-Preload` | `preload` | `Rel:col1,col2` pipe-separated |
|
||||
| `X-Fetch-RowNumber` | `fetch_row_number` | string |
|
||||
| `X-CQL-SEL-{col}` | `computedColumns` | expression |
|
||||
| `X-Custom-SQL-W` | `customOperators` | SQL AND-joined |
|
||||
|
||||
### Utility Functions
|
||||
|
||||
```typescript
|
||||
import { buildHeaders, encodeHeaderValue, decodeHeaderValue } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const headers = buildHeaders({ columns: ['id', 'name'], limit: 10 });
|
||||
// => { 'X-Select-Fields': 'id,name', 'X-Limit': '10' }
|
||||
|
||||
const encoded = encodeHeaderValue('complex value'); // 'ZIP_...'
|
||||
const decoded = decodeHeaderValue(encoded); // 'complex value'
|
||||
```
|
||||
|
||||
## WebSocket Client
|
||||
|
||||
Real-time CRUD with subscriptions. Maps to Go `pkg/websocketspec`.
|
||||
|
||||
```typescript
|
||||
import { WebSocketClient, getWebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
const ws = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
heartbeatInterval: 30000,
|
||||
});
|
||||
// Or: const ws = getWebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
|
||||
await ws.connect();
|
||||
|
||||
// CRUD
|
||||
const users = await ws.read('users', { schema: 'public', limit: 10 });
|
||||
const created = await ws.create('users', { name: 'New' }, { schema: 'public' });
|
||||
await ws.update('users', '1', { name: 'Updated' });
|
||||
await ws.delete('users', '1');
|
||||
|
||||
// Subscribe to changes
|
||||
const subId = await ws.subscribe('users', (notification) => {
|
||||
console.log(notification.operation, notification.data);
|
||||
});
|
||||
|
||||
// Unsubscribe
|
||||
await ws.unsubscribe(subId);
|
||||
|
||||
// Events
|
||||
ws.on('connect', () => console.log('connected'));
|
||||
ws.on('disconnect', () => console.log('disconnected'));
|
||||
ws.on('error', (err) => console.error(err));
|
||||
|
||||
ws.disconnect();
|
||||
```
|
||||
|
||||
## Types
|
||||
|
||||
All types align with Go `pkg/common/types.go`.
|
||||
|
||||
### Key Types
|
||||
|
||||
```typescript
|
||||
interface Options {
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
preload?: PreloadOption[];
|
||||
customOperators?: CustomOperator[];
|
||||
computedColumns?: ComputedColumn[];
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
interface FilterOption {
|
||||
column: string;
|
||||
operator: Operator | string;
|
||||
value: any;
|
||||
logic_operator?: 'AND' | 'OR';
|
||||
}
|
||||
|
||||
// Operators: eq, neq, gt, gte, lt, lte, like, ilike, in,
|
||||
// contains, startswith, endswith, between,
|
||||
// between_inclusive, is_null, is_not_null
|
||||
|
||||
interface APIResponse<T> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
metadata?: Metadata;
|
||||
error?: APIError;
|
||||
}
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
```bash
|
||||
pnpm install
|
||||
pnpm run build # dist/index.js (ES) + dist/index.cjs (CJS) + .d.ts
|
||||
pnpm run test # vitest
|
||||
pnpm run lint # eslint
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
@@ -1,530 +0,0 @@
|
||||
# WebSocketSpec JavaScript Client
|
||||
|
||||
A TypeScript/JavaScript client for connecting to WebSocketSpec servers with full support for real-time subscriptions, CRUD operations, and automatic reconnection.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install @warkypublic/resolvespec-js
|
||||
# or
|
||||
yarn add @warkypublic/resolvespec-js
|
||||
# or
|
||||
pnpm add @warkypublic/resolvespec-js
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```typescript
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
// Create client
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
debug: true
|
||||
});
|
||||
|
||||
// Connect
|
||||
await client.connect();
|
||||
|
||||
// Read records
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
limit: 10
|
||||
});
|
||||
|
||||
// Subscribe to changes
|
||||
const subscriptionId = await client.subscribe('users', (notification) => {
|
||||
console.log('User changed:', notification.operation, notification.data);
|
||||
}, { schema: 'public' });
|
||||
|
||||
// Clean up
|
||||
await client.unsubscribe(subscriptionId);
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- **Real-Time Updates**: Subscribe to entity changes and receive instant notifications
|
||||
- **Full CRUD Support**: Create, read, update, and delete operations
|
||||
- **TypeScript Support**: Full type definitions included
|
||||
- **Auto Reconnection**: Automatic reconnection with configurable retry logic
|
||||
- **Heartbeat**: Built-in keepalive mechanism
|
||||
- **Event System**: Listen to connection, error, and message events
|
||||
- **Promise-based API**: All async operations return promises
|
||||
- **Filter & Sort**: Advanced querying with filters, sorting, and pagination
|
||||
- **Preloading**: Load related entities in a single query
|
||||
|
||||
## Configuration
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws', // WebSocket server URL
|
||||
reconnect: true, // Enable auto-reconnection
|
||||
reconnectInterval: 3000, // Reconnection delay (ms)
|
||||
maxReconnectAttempts: 10, // Max reconnection attempts
|
||||
heartbeatInterval: 30000, // Heartbeat interval (ms)
|
||||
debug: false // Enable debug logging
|
||||
});
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### Connection Management
|
||||
|
||||
#### `connect(): Promise<void>`
|
||||
Connect to the WebSocket server.
|
||||
|
||||
```typescript
|
||||
await client.connect();
|
||||
```
|
||||
|
||||
#### `disconnect(): void`
|
||||
Disconnect from the server.
|
||||
|
||||
```typescript
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
#### `isConnected(): boolean`
|
||||
Check if currently connected.
|
||||
|
||||
```typescript
|
||||
if (client.isConnected()) {
|
||||
console.log('Connected!');
|
||||
}
|
||||
```
|
||||
|
||||
#### `getState(): ConnectionState`
|
||||
Get current connection state: `'connecting'`, `'connected'`, `'disconnecting'`, `'disconnected'`, or `'reconnecting'`.
|
||||
|
||||
```typescript
|
||||
const state = client.getState();
|
||||
console.log('State:', state);
|
||||
```
|
||||
|
||||
### CRUD Operations
|
||||
|
||||
#### `read<T>(entity: string, options?): Promise<T>`
|
||||
Read records from an entity.
|
||||
|
||||
```typescript
|
||||
// Read all active users
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
columns: ['id', 'name', 'email'],
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' }
|
||||
],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
|
||||
// Read single record by ID
|
||||
const user = await client.read('users', {
|
||||
schema: 'public',
|
||||
record_id: '123'
|
||||
});
|
||||
|
||||
// Read with preloading
|
||||
const posts = await client.read('posts', {
|
||||
schema: 'public',
|
||||
preload: [
|
||||
{
|
||||
relation: 'user',
|
||||
columns: ['id', 'name', 'email']
|
||||
},
|
||||
{
|
||||
relation: 'comments',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'approved' }
|
||||
]
|
||||
}
|
||||
]
|
||||
});
|
||||
```
|
||||
|
||||
#### `create<T>(entity: string, data: any, options?): Promise<T>`
|
||||
Create a new record.
|
||||
|
||||
```typescript
|
||||
const newUser = await client.create('users', {
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
status: 'active'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `update<T>(entity: string, id: string, data: any, options?): Promise<T>`
|
||||
Update an existing record.
|
||||
|
||||
```typescript
|
||||
const updatedUser = await client.update('users', '123', {
|
||||
name: 'John Updated',
|
||||
email: 'john.new@example.com'
|
||||
}, {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `delete(entity: string, id: string, options?): Promise<void>`
|
||||
Delete a record.
|
||||
|
||||
```typescript
|
||||
await client.delete('users', '123', {
|
||||
schema: 'public'
|
||||
});
|
||||
```
|
||||
|
||||
#### `meta<T>(entity: string, options?): Promise<T>`
|
||||
Get metadata for an entity.
|
||||
|
||||
```typescript
|
||||
const metadata = await client.meta('users', {
|
||||
schema: 'public'
|
||||
});
|
||||
console.log('Columns:', metadata.columns);
|
||||
console.log('Primary key:', metadata.primary_key);
|
||||
```
|
||||
|
||||
### Subscriptions
|
||||
|
||||
#### `subscribe(entity: string, callback: Function, options?): Promise<string>`
|
||||
Subscribe to entity changes.
|
||||
|
||||
```typescript
|
||||
const subscriptionId = await client.subscribe(
|
||||
'users',
|
||||
(notification) => {
|
||||
console.log('Operation:', notification.operation); // 'create', 'update', or 'delete'
|
||||
console.log('Data:', notification.data);
|
||||
console.log('Timestamp:', notification.timestamp);
|
||||
},
|
||||
{
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
);
|
||||
```
|
||||
|
||||
#### `unsubscribe(subscriptionId: string): Promise<void>`
|
||||
Unsubscribe from entity changes.
|
||||
|
||||
```typescript
|
||||
await client.unsubscribe(subscriptionId);
|
||||
```
|
||||
|
||||
#### `getSubscriptions(): Subscription[]`
|
||||
Get list of active subscriptions.
|
||||
|
||||
```typescript
|
||||
const subscriptions = client.getSubscriptions();
|
||||
console.log('Active subscriptions:', subscriptions.length);
|
||||
```
|
||||
|
||||
### Event Handling
|
||||
|
||||
#### `on(event: string, callback: Function): void`
|
||||
Add event listener.
|
||||
|
||||
```typescript
|
||||
// Connection events
|
||||
client.on('connect', () => {
|
||||
console.log('Connected!');
|
||||
});
|
||||
|
||||
client.on('disconnect', (event) => {
|
||||
console.log('Disconnected:', event.code, event.reason);
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('Error:', error);
|
||||
});
|
||||
|
||||
// State changes
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('State:', state);
|
||||
});
|
||||
|
||||
// All messages
|
||||
client.on('message', (message) => {
|
||||
console.log('Message:', message);
|
||||
});
|
||||
```
|
||||
|
||||
#### `off(event: string): void`
|
||||
Remove event listener.
|
||||
|
||||
```typescript
|
||||
client.off('connect');
|
||||
```
|
||||
|
||||
## Filter Operators
|
||||
|
||||
- `eq` - Equal (=)
|
||||
- `neq` - Not Equal (!=)
|
||||
- `gt` - Greater Than (>)
|
||||
- `gte` - Greater Than or Equal (>=)
|
||||
- `lt` - Less Than (<)
|
||||
- `lte` - Less Than or Equal (<=)
|
||||
- `like` - LIKE (case-sensitive)
|
||||
- `ilike` - ILIKE (case-insensitive)
|
||||
- `in` - IN (array of values)
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic CRUD
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Create
|
||||
const user = await client.create('users', {
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com'
|
||||
});
|
||||
|
||||
// Read
|
||||
const users = await client.read('users', {
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
|
||||
// Update
|
||||
await client.update('users', user.id, { name: 'Alice Updated' });
|
||||
|
||||
// Delete
|
||||
await client.delete('users', user.id);
|
||||
|
||||
client.disconnect();
|
||||
```
|
||||
|
||||
### Real-Time Subscriptions
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to all user changes
|
||||
const subId = await client.subscribe('users', (notification) => {
|
||||
switch (notification.operation) {
|
||||
case 'create':
|
||||
console.log('New user:', notification.data);
|
||||
break;
|
||||
case 'update':
|
||||
console.log('User updated:', notification.data);
|
||||
break;
|
||||
case 'delete':
|
||||
console.log('User deleted:', notification.data);
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
// Later: unsubscribe
|
||||
await client.unsubscribe(subId);
|
||||
```
|
||||
|
||||
### React Integration
|
||||
|
||||
```typescript
|
||||
import { useEffect, useState } from 'react';
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
function useWebSocket(url: string) {
|
||||
const [client] = useState(() => new WebSocketClient({ url }));
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
client.on('connect', () => setIsConnected(true));
|
||||
client.on('disconnect', () => setIsConnected(false));
|
||||
client.connect();
|
||||
|
||||
return () => client.disconnect();
|
||||
}, [client]);
|
||||
|
||||
return { client, isConnected };
|
||||
}
|
||||
|
||||
function UsersComponent() {
|
||||
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
|
||||
const [users, setUsers] = useState([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isConnected) return;
|
||||
|
||||
const loadUsers = async () => {
|
||||
// Subscribe to changes
|
||||
await client.subscribe('users', (notification) => {
|
||||
if (notification.operation === 'create') {
|
||||
setUsers(prev => [...prev, notification.data]);
|
||||
} else if (notification.operation === 'update') {
|
||||
setUsers(prev => prev.map(u =>
|
||||
u.id === notification.data.id ? notification.data : u
|
||||
));
|
||||
} else if (notification.operation === 'delete') {
|
||||
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
|
||||
}
|
||||
});
|
||||
|
||||
// Load initial data
|
||||
const data = await client.read('users');
|
||||
setUsers(data);
|
||||
};
|
||||
|
||||
loadUsers();
|
||||
}, [client, isConnected]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h2>Users {isConnected ? '🟢' : '🔴'}</h2>
|
||||
{users.map(user => (
|
||||
<div key={user.id}>{user.name}</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### TypeScript with Typed Models
|
||||
|
||||
```typescript
|
||||
interface User {
|
||||
id: number;
|
||||
name: string;
|
||||
email: string;
|
||||
status: 'active' | 'inactive';
|
||||
}
|
||||
|
||||
interface Post {
|
||||
id: number;
|
||||
title: string;
|
||||
content: string;
|
||||
user_id: number;
|
||||
user?: User;
|
||||
}
|
||||
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Type-safe operations
|
||||
const users = await client.read<User[]>('users', {
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
|
||||
const newUser = await client.create<User>('users', {
|
||||
name: 'Bob',
|
||||
email: 'bob@example.com',
|
||||
status: 'active'
|
||||
});
|
||||
|
||||
// Type-safe subscriptions
|
||||
await client.subscribe(
|
||||
'posts',
|
||||
(notification) => {
|
||||
const post = notification.data as Post;
|
||||
console.log('Post:', post.title);
|
||||
}
|
||||
);
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
maxReconnectAttempts: 5
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('Connection error:', error);
|
||||
});
|
||||
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('State:', state);
|
||||
if (state === 'reconnecting') {
|
||||
console.log('Attempting to reconnect...');
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
await client.connect();
|
||||
|
||||
try {
|
||||
const user = await client.read('users', { record_id: '999' });
|
||||
} catch (error) {
|
||||
console.error('Record not found:', error);
|
||||
}
|
||||
|
||||
try {
|
||||
await client.create('users', { /* invalid data */ });
|
||||
} catch (error) {
|
||||
console.error('Validation failed:', error);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Connection failed:', error);
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple Subscriptions
|
||||
|
||||
```typescript
|
||||
const client = new WebSocketClient({ url: 'ws://localhost:8080/ws' });
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to multiple entities
|
||||
const userSub = await client.subscribe('users', (n) => {
|
||||
console.log('[Users]', n.operation, n.data);
|
||||
});
|
||||
|
||||
const postSub = await client.subscribe('posts', (n) => {
|
||||
console.log('[Posts]', n.operation, n.data);
|
||||
}, {
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'published' }]
|
||||
});
|
||||
|
||||
const commentSub = await client.subscribe('comments', (n) => {
|
||||
console.log('[Comments]', n.operation, n.data);
|
||||
});
|
||||
|
||||
// Check active subscriptions
|
||||
console.log('Active:', client.getSubscriptions().length);
|
||||
|
||||
// Clean up
|
||||
await client.unsubscribe(userSub);
|
||||
await client.unsubscribe(postSub);
|
||||
await client.unsubscribe(commentSub);
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always Clean Up**: Call `disconnect()` when done to close the connection properly
|
||||
2. **Use TypeScript**: Leverage type definitions for better type safety
|
||||
3. **Handle Errors**: Always wrap operations in try-catch blocks
|
||||
4. **Limit Subscriptions**: Don't create too many subscriptions per connection
|
||||
5. **Use Filters**: Apply filters to subscriptions to reduce unnecessary notifications
|
||||
6. **Connection State**: Check `isConnected()` before operations
|
||||
7. **Event Listeners**: Remove event listeners when no longer needed with `off()`
|
||||
8. **Reconnection**: Enable auto-reconnection for production apps
|
||||
|
||||
## Browser Support
|
||||
|
||||
- Chrome/Edge 88+
|
||||
- Firefox 85+
|
||||
- Safari 14+
|
||||
- Node.js 14.16+
|
||||
|
||||
## License
|
||||
|
||||
MIT
|
||||
1
resolvespec-js/dist/index.cjs
vendored
Normal file
1
resolvespec-js/dist/index.cjs
vendored
Normal file
File diff suppressed because one or more lines are too long
366
resolvespec-js/dist/index.d.ts
vendored
Normal file
366
resolvespec-js/dist/index.d.ts
vendored
Normal file
@@ -0,0 +1,366 @@
|
||||
export declare interface APIError {
|
||||
code: string;
|
||||
message: string;
|
||||
details?: any;
|
||||
detail?: string;
|
||||
}
|
||||
|
||||
export declare interface APIResponse<T = any> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
metadata?: Metadata;
|
||||
error?: APIError;
|
||||
}
|
||||
|
||||
/**
|
||||
* Build HTTP headers from Options, matching Go's restheadspec handler conventions.
|
||||
*
|
||||
* Header mapping:
|
||||
* - X-Select-Fields: comma-separated columns
|
||||
* - X-Not-Select-Fields: comma-separated omit_columns
|
||||
* - X-FieldFilter-{col}: exact match (eq)
|
||||
* - X-SearchOp-{operator}-{col}: AND filter
|
||||
* - X-SearchOr-{operator}-{col}: OR filter
|
||||
* - X-Sort: +col (asc), -col (desc)
|
||||
* - X-Limit, X-Offset: pagination
|
||||
* - X-Cursor-Forward, X-Cursor-Backward: cursor pagination
|
||||
* - X-Preload: RelationName:field1,field2 pipe-separated
|
||||
* - X-Fetch-RowNumber: row number fetch
|
||||
* - X-CQL-SEL-{col}: computed columns
|
||||
* - X-Custom-SQL-W: custom operators (AND)
|
||||
*/
|
||||
export declare function buildHeaders(options: Options): Record<string, string>;
|
||||
|
||||
export declare interface ClientConfig {
|
||||
baseUrl: string;
|
||||
token?: string;
|
||||
}
|
||||
|
||||
export declare interface Column {
|
||||
name: string;
|
||||
type: string;
|
||||
is_nullable: boolean;
|
||||
is_primary: boolean;
|
||||
is_unique: boolean;
|
||||
has_index: boolean;
|
||||
}
|
||||
|
||||
export declare interface ComputedColumn {
|
||||
name: string;
|
||||
expression: string;
|
||||
}
|
||||
|
||||
export declare type ConnectionState = 'connecting' | 'connected' | 'disconnecting' | 'disconnected' | 'reconnecting';
|
||||
|
||||
export declare interface CustomOperator {
|
||||
name: string;
|
||||
sql: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode a header value that may be base64 encoded with ZIP_ or __ prefix.
|
||||
*/
|
||||
export declare function decodeHeaderValue(value: string): string;
|
||||
|
||||
/**
|
||||
* Encode a value with base64 and ZIP_ prefix for complex header values.
|
||||
*/
|
||||
export declare function encodeHeaderValue(value: string): string;
|
||||
|
||||
export declare interface FilterOption {
|
||||
column: string;
|
||||
operator: Operator | string;
|
||||
value: any;
|
||||
logic_operator?: 'AND' | 'OR';
|
||||
}
|
||||
|
||||
export declare function getHeaderSpecClient(config: ClientConfig): HeaderSpecClient;
|
||||
|
||||
export declare function getResolveSpecClient(config: ClientConfig): ResolveSpecClient;
|
||||
|
||||
export declare function getWebSocketClient(config: WebSocketClientConfig): WebSocketClient;
|
||||
|
||||
/**
|
||||
* HeaderSpec REST client.
|
||||
* Sends query options via HTTP headers instead of request body, matching the Go restheadspec handler.
|
||||
*
|
||||
* HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete
|
||||
*/
|
||||
export declare class HeaderSpecClient {
|
||||
private config;
|
||||
constructor(config: ClientConfig);
|
||||
private buildUrl;
|
||||
private baseHeaders;
|
||||
private fetchWithError;
|
||||
read<T = any>(schema: string, entity: string, id?: string, options?: Options): Promise<APIResponse<T>>;
|
||||
create<T = any>(schema: string, entity: string, data: any, options?: Options): Promise<APIResponse<T>>;
|
||||
update<T = any>(schema: string, entity: string, id: string, data: any, options?: Options): Promise<APIResponse<T>>;
|
||||
delete(schema: string, entity: string, id: string): Promise<APIResponse<void>>;
|
||||
}
|
||||
|
||||
export declare type MessageType = 'request' | 'response' | 'notification' | 'subscription' | 'error' | 'ping' | 'pong';
|
||||
|
||||
export declare interface Metadata {
|
||||
total: number;
|
||||
count: number;
|
||||
filtered: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
row_number?: number;
|
||||
}
|
||||
|
||||
export declare type Operation = 'read' | 'create' | 'update' | 'delete';
|
||||
|
||||
export declare type Operator = 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte' | 'like' | 'ilike' | 'in' | 'contains' | 'startswith' | 'endswith' | 'between' | 'between_inclusive' | 'is_null' | 'is_not_null';
|
||||
|
||||
export declare interface Options {
|
||||
preload?: PreloadOption[];
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
customOperators?: CustomOperator[];
|
||||
computedColumns?: ComputedColumn[];
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
export declare interface Parameter {
|
||||
name: string;
|
||||
value: string;
|
||||
sequence?: number;
|
||||
}
|
||||
|
||||
export declare interface PreloadOption {
|
||||
relation: string;
|
||||
table_name?: string;
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
sort?: SortOption[];
|
||||
filters?: FilterOption[];
|
||||
where?: string;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
updatable?: boolean;
|
||||
computed_ql?: Record<string, string>;
|
||||
recursive?: boolean;
|
||||
primary_key?: string;
|
||||
related_key?: string;
|
||||
foreign_key?: string;
|
||||
recursive_child_key?: string;
|
||||
sql_joins?: string[];
|
||||
join_aliases?: string[];
|
||||
}
|
||||
|
||||
export declare interface RequestBody {
|
||||
operation: Operation;
|
||||
id?: number | string | string[];
|
||||
data?: any | any[];
|
||||
options?: Options;
|
||||
}
|
||||
|
||||
export declare class ResolveSpecClient {
|
||||
private config;
|
||||
constructor(config: ClientConfig);
|
||||
private buildUrl;
|
||||
private baseHeaders;
|
||||
private fetchWithError;
|
||||
getMetadata(schema: string, entity: string): Promise<APIResponse<TableMetadata>>;
|
||||
read<T = any>(schema: string, entity: string, id?: number | string | string[], options?: Options): Promise<APIResponse<T>>;
|
||||
create<T = any>(schema: string, entity: string, data: any | any[], options?: Options): Promise<APIResponse<T>>;
|
||||
update<T = any>(schema: string, entity: string, data: any | any[], id?: number | string | string[], options?: Options): Promise<APIResponse<T>>;
|
||||
delete(schema: string, entity: string, id: number | string): Promise<APIResponse<void>>;
|
||||
}
|
||||
|
||||
export declare type SortDirection = 'asc' | 'desc' | 'ASC' | 'DESC';
|
||||
|
||||
export declare interface SortOption {
|
||||
column: string;
|
||||
direction: SortDirection;
|
||||
}
|
||||
|
||||
export declare interface Subscription {
|
||||
id: string;
|
||||
entity: string;
|
||||
schema?: string;
|
||||
options?: WSOptions;
|
||||
callback?: (notification: WSNotificationMessage) => void;
|
||||
}
|
||||
|
||||
export declare interface SubscriptionOptions {
|
||||
filters?: FilterOption[];
|
||||
onNotification?: (notification: WSNotificationMessage) => void;
|
||||
}
|
||||
|
||||
export declare interface TableMetadata {
|
||||
schema: string;
|
||||
table: string;
|
||||
columns: Column[];
|
||||
relations: string[];
|
||||
}
|
||||
|
||||
export declare class WebSocketClient {
|
||||
private ws;
|
||||
private config;
|
||||
private messageHandlers;
|
||||
private subscriptions;
|
||||
private eventListeners;
|
||||
private state;
|
||||
private reconnectAttempts;
|
||||
private reconnectTimer;
|
||||
private heartbeatTimer;
|
||||
private isManualClose;
|
||||
constructor(config: WebSocketClientConfig);
|
||||
connect(): Promise<void>;
|
||||
disconnect(): void;
|
||||
request<T = any>(operation: WSOperation, entity: string, options?: {
|
||||
schema?: string;
|
||||
record_id?: string;
|
||||
data?: any;
|
||||
options?: WSOptions;
|
||||
}): Promise<T>;
|
||||
read<T = any>(entity: string, options?: {
|
||||
schema?: string;
|
||||
record_id?: string;
|
||||
filters?: FilterOption[];
|
||||
columns?: string[];
|
||||
sort?: SortOption[];
|
||||
preload?: PreloadOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
}): Promise<T>;
|
||||
create<T = any>(entity: string, data: any, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T>;
|
||||
update<T = any>(entity: string, id: string, data: any, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T>;
|
||||
delete(entity: string, id: string, options?: {
|
||||
schema?: string;
|
||||
}): Promise<void>;
|
||||
meta<T = any>(entity: string, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T>;
|
||||
subscribe(entity: string, callback: (notification: WSNotificationMessage) => void, options?: {
|
||||
schema?: string;
|
||||
filters?: FilterOption[];
|
||||
}): Promise<string>;
|
||||
unsubscribe(subscriptionId: string): Promise<void>;
|
||||
getSubscriptions(): Subscription[];
|
||||
getState(): ConnectionState;
|
||||
isConnected(): boolean;
|
||||
on<K extends keyof WebSocketClientEvents>(event: K, callback: WebSocketClientEvents[K]): void;
|
||||
off<K extends keyof WebSocketClientEvents>(event: K): void;
|
||||
private handleMessage;
|
||||
private handleResponse;
|
||||
private handleNotification;
|
||||
private send;
|
||||
private startHeartbeat;
|
||||
private stopHeartbeat;
|
||||
private setState;
|
||||
private ensureConnected;
|
||||
private emit;
|
||||
private log;
|
||||
}
|
||||
|
||||
export declare interface WebSocketClientConfig {
|
||||
url: string;
|
||||
reconnect?: boolean;
|
||||
reconnectInterval?: number;
|
||||
maxReconnectAttempts?: number;
|
||||
heartbeatInterval?: number;
|
||||
debug?: boolean;
|
||||
}
|
||||
|
||||
export declare interface WebSocketClientEvents {
|
||||
connect: () => void;
|
||||
disconnect: (event: CloseEvent) => void;
|
||||
error: (error: Error) => void;
|
||||
message: (message: WSMessage) => void;
|
||||
stateChange: (state: ConnectionState) => void;
|
||||
}
|
||||
|
||||
export declare interface WSErrorInfo {
|
||||
code: string;
|
||||
message: string;
|
||||
details?: Record<string, any>;
|
||||
}
|
||||
|
||||
export declare interface WSMessage {
|
||||
id?: string;
|
||||
type: MessageType;
|
||||
operation?: WSOperation;
|
||||
schema?: string;
|
||||
entity?: string;
|
||||
record_id?: string;
|
||||
data?: any;
|
||||
options?: WSOptions;
|
||||
subscription_id?: string;
|
||||
success?: boolean;
|
||||
error?: WSErrorInfo;
|
||||
metadata?: Record<string, any>;
|
||||
timestamp?: string;
|
||||
}
|
||||
|
||||
export declare interface WSNotificationMessage {
|
||||
type: 'notification';
|
||||
operation: WSOperation;
|
||||
subscription_id: string;
|
||||
schema?: string;
|
||||
entity: string;
|
||||
data: any;
|
||||
timestamp: string;
|
||||
}
|
||||
|
||||
export declare type WSOperation = 'read' | 'create' | 'update' | 'delete' | 'subscribe' | 'unsubscribe' | 'meta';
|
||||
|
||||
export declare interface WSOptions {
|
||||
filters?: FilterOption[];
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
preload?: PreloadOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
export declare interface WSRequestMessage {
|
||||
id: string;
|
||||
type: 'request';
|
||||
operation: WSOperation;
|
||||
schema?: string;
|
||||
entity: string;
|
||||
record_id?: string;
|
||||
data?: any;
|
||||
options?: WSOptions;
|
||||
}
|
||||
|
||||
export declare interface WSResponseMessage {
|
||||
id: string;
|
||||
type: 'response';
|
||||
success: boolean;
|
||||
data?: any;
|
||||
error?: WSErrorInfo;
|
||||
metadata?: Record<string, any>;
|
||||
timestamp: string;
|
||||
}
|
||||
|
||||
export declare interface WSSubscriptionMessage {
|
||||
id: string;
|
||||
type: 'subscription';
|
||||
operation: 'subscribe' | 'unsubscribe';
|
||||
schema?: string;
|
||||
entity: string;
|
||||
options?: WSOptions;
|
||||
subscription_id?: string;
|
||||
}
|
||||
|
||||
export { }
|
||||
469
resolvespec-js/dist/index.js
vendored
Normal file
469
resolvespec-js/dist/index.js
vendored
Normal file
@@ -0,0 +1,469 @@
|
||||
import { v4 as l } from "uuid";
|
||||
const d = /* @__PURE__ */ new Map();
|
||||
function E(n) {
|
||||
const e = n.baseUrl;
|
||||
let t = d.get(e);
|
||||
return t || (t = new g(n), d.set(e, t)), t;
|
||||
}
|
||||
class g {
|
||||
constructor(e) {
|
||||
this.config = e;
|
||||
}
|
||||
buildUrl(e, t, s) {
|
||||
let r = `${this.config.baseUrl}/${e}/${t}`;
|
||||
return s && (r += `/${s}`), r;
|
||||
}
|
||||
baseHeaders() {
|
||||
const e = {
|
||||
"Content-Type": "application/json"
|
||||
};
|
||||
return this.config.token && (e.Authorization = `Bearer ${this.config.token}`), e;
|
||||
}
|
||||
async fetchWithError(e, t) {
|
||||
const s = await fetch(e, t), r = await s.json();
|
||||
if (!s.ok)
|
||||
throw new Error(r.error?.message || "An error occurred");
|
||||
return r;
|
||||
}
|
||||
async getMetadata(e, t) {
|
||||
const s = this.buildUrl(e, t);
|
||||
return this.fetchWithError(s, {
|
||||
method: "GET",
|
||||
headers: this.baseHeaders()
|
||||
});
|
||||
}
|
||||
async read(e, t, s, r) {
|
||||
const i = typeof s == "number" || typeof s == "string" ? String(s) : void 0, a = this.buildUrl(e, t, i), c = {
|
||||
operation: "read",
|
||||
id: Array.isArray(s) ? s : void 0,
|
||||
options: r
|
||||
};
|
||||
return this.fetchWithError(a, {
|
||||
method: "POST",
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(c)
|
||||
});
|
||||
}
|
||||
async create(e, t, s, r) {
|
||||
const i = this.buildUrl(e, t), a = {
|
||||
operation: "create",
|
||||
data: s,
|
||||
options: r
|
||||
};
|
||||
return this.fetchWithError(i, {
|
||||
method: "POST",
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(a)
|
||||
});
|
||||
}
|
||||
async update(e, t, s, r, i) {
|
||||
const a = typeof r == "number" || typeof r == "string" ? String(r) : void 0, c = this.buildUrl(e, t, a), o = {
|
||||
operation: "update",
|
||||
id: Array.isArray(r) ? r : void 0,
|
||||
data: s,
|
||||
options: i
|
||||
};
|
||||
return this.fetchWithError(c, {
|
||||
method: "POST",
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(o)
|
||||
});
|
||||
}
|
||||
async delete(e, t, s) {
|
||||
const r = this.buildUrl(e, t, String(s)), i = {
|
||||
operation: "delete"
|
||||
};
|
||||
return this.fetchWithError(r, {
|
||||
method: "POST",
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(i)
|
||||
});
|
||||
}
|
||||
}
|
||||
const f = /* @__PURE__ */ new Map();
|
||||
function _(n) {
|
||||
const e = n.url;
|
||||
let t = f.get(e);
|
||||
return t || (t = new p(n), f.set(e, t)), t;
|
||||
}
|
||||
class p {
|
||||
constructor(e) {
|
||||
this.ws = null, this.messageHandlers = /* @__PURE__ */ new Map(), this.subscriptions = /* @__PURE__ */ new Map(), this.eventListeners = {}, this.state = "disconnected", this.reconnectAttempts = 0, this.reconnectTimer = null, this.heartbeatTimer = null, this.isManualClose = !1, this.config = {
|
||||
url: e.url,
|
||||
reconnect: e.reconnect ?? !0,
|
||||
reconnectInterval: e.reconnectInterval ?? 3e3,
|
||||
maxReconnectAttempts: e.maxReconnectAttempts ?? 10,
|
||||
heartbeatInterval: e.heartbeatInterval ?? 3e4,
|
||||
debug: e.debug ?? !1
|
||||
};
|
||||
}
|
||||
async connect() {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
this.log("Already connected");
|
||||
return;
|
||||
}
|
||||
return this.isManualClose = !1, this.setState("connecting"), new Promise((e, t) => {
|
||||
try {
|
||||
this.ws = new WebSocket(this.config.url), this.ws.onopen = () => {
|
||||
this.log("Connected to WebSocket server"), this.setState("connected"), this.reconnectAttempts = 0, this.startHeartbeat(), this.emit("connect"), e();
|
||||
}, this.ws.onmessage = (s) => {
|
||||
this.handleMessage(s.data);
|
||||
}, this.ws.onerror = (s) => {
|
||||
this.log("WebSocket error:", s);
|
||||
const r = new Error("WebSocket connection error");
|
||||
this.emit("error", r), t(r);
|
||||
}, this.ws.onclose = (s) => {
|
||||
this.log("WebSocket closed:", s.code, s.reason), this.stopHeartbeat(), this.setState("disconnected"), this.emit("disconnect", s), this.config.reconnect && !this.isManualClose && this.reconnectAttempts < this.config.maxReconnectAttempts && (this.reconnectAttempts++, this.log(`Reconnection attempt ${this.reconnectAttempts}/${this.config.maxReconnectAttempts}`), this.setState("reconnecting"), this.reconnectTimer = setTimeout(() => {
|
||||
this.connect().catch((r) => {
|
||||
this.log("Reconnection failed:", r);
|
||||
});
|
||||
}, this.config.reconnectInterval));
|
||||
};
|
||||
} catch (s) {
|
||||
t(s);
|
||||
}
|
||||
});
|
||||
}
|
||||
disconnect() {
|
||||
this.isManualClose = !0, this.reconnectTimer && (clearTimeout(this.reconnectTimer), this.reconnectTimer = null), this.stopHeartbeat(), this.ws && (this.setState("disconnecting"), this.ws.close(), this.ws = null), this.setState("disconnected"), this.messageHandlers.clear();
|
||||
}
|
||||
async request(e, t, s) {
|
||||
this.ensureConnected();
|
||||
const r = l(), i = {
|
||||
id: r,
|
||||
type: "request",
|
||||
operation: e,
|
||||
entity: t,
|
||||
schema: s?.schema,
|
||||
record_id: s?.record_id,
|
||||
data: s?.data,
|
||||
options: s?.options
|
||||
};
|
||||
return new Promise((a, c) => {
|
||||
this.messageHandlers.set(r, (o) => {
|
||||
o.success ? a(o.data) : c(new Error(o.error?.message || "Request failed"));
|
||||
}), this.send(i), setTimeout(() => {
|
||||
this.messageHandlers.has(r) && (this.messageHandlers.delete(r), c(new Error("Request timeout")));
|
||||
}, 3e4);
|
||||
});
|
||||
}
|
||||
async read(e, t) {
|
||||
return this.request("read", e, {
|
||||
schema: t?.schema,
|
||||
record_id: t?.record_id,
|
||||
options: {
|
||||
filters: t?.filters,
|
||||
columns: t?.columns,
|
||||
sort: t?.sort,
|
||||
preload: t?.preload,
|
||||
limit: t?.limit,
|
||||
offset: t?.offset
|
||||
}
|
||||
});
|
||||
}
|
||||
async create(e, t, s) {
|
||||
return this.request("create", e, {
|
||||
schema: s?.schema,
|
||||
data: t
|
||||
});
|
||||
}
|
||||
async update(e, t, s, r) {
|
||||
return this.request("update", e, {
|
||||
schema: r?.schema,
|
||||
record_id: t,
|
||||
data: s
|
||||
});
|
||||
}
|
||||
async delete(e, t, s) {
|
||||
await this.request("delete", e, {
|
||||
schema: s?.schema,
|
||||
record_id: t
|
||||
});
|
||||
}
|
||||
async meta(e, t) {
|
||||
return this.request("meta", e, {
|
||||
schema: t?.schema
|
||||
});
|
||||
}
|
||||
async subscribe(e, t, s) {
|
||||
this.ensureConnected();
|
||||
const r = l(), i = {
|
||||
id: r,
|
||||
type: "subscription",
|
||||
operation: "subscribe",
|
||||
entity: e,
|
||||
schema: s?.schema,
|
||||
options: {
|
||||
filters: s?.filters
|
||||
}
|
||||
};
|
||||
return new Promise((a, c) => {
|
||||
this.messageHandlers.set(r, (o) => {
|
||||
if (o.success && o.data?.subscription_id) {
|
||||
const h = o.data.subscription_id;
|
||||
this.subscriptions.set(h, {
|
||||
id: h,
|
||||
entity: e,
|
||||
schema: s?.schema,
|
||||
options: { filters: s?.filters },
|
||||
callback: t
|
||||
}), this.log(`Subscribed to ${e} with ID: ${h}`), a(h);
|
||||
} else
|
||||
c(new Error(o.error?.message || "Subscription failed"));
|
||||
}), this.send(i), setTimeout(() => {
|
||||
this.messageHandlers.has(r) && (this.messageHandlers.delete(r), c(new Error("Subscription timeout")));
|
||||
}, 1e4);
|
||||
});
|
||||
}
|
||||
async unsubscribe(e) {
|
||||
this.ensureConnected();
|
||||
const t = l(), s = {
|
||||
id: t,
|
||||
type: "subscription",
|
||||
operation: "unsubscribe",
|
||||
subscription_id: e
|
||||
};
|
||||
return new Promise((r, i) => {
|
||||
this.messageHandlers.set(t, (a) => {
|
||||
a.success ? (this.subscriptions.delete(e), this.log(`Unsubscribed from ${e}`), r()) : i(new Error(a.error?.message || "Unsubscribe failed"));
|
||||
}), this.send(s), setTimeout(() => {
|
||||
this.messageHandlers.has(t) && (this.messageHandlers.delete(t), i(new Error("Unsubscribe timeout")));
|
||||
}, 1e4);
|
||||
});
|
||||
}
|
||||
getSubscriptions() {
|
||||
return Array.from(this.subscriptions.values());
|
||||
}
|
||||
getState() {
|
||||
return this.state;
|
||||
}
|
||||
isConnected() {
|
||||
return this.ws?.readyState === WebSocket.OPEN;
|
||||
}
|
||||
on(e, t) {
|
||||
this.eventListeners[e] = t;
|
||||
}
|
||||
off(e) {
|
||||
delete this.eventListeners[e];
|
||||
}
|
||||
// Private methods
|
||||
handleMessage(e) {
|
||||
try {
|
||||
const t = JSON.parse(e);
|
||||
switch (this.log("Received message:", t), this.emit("message", t), t.type) {
|
||||
case "response":
|
||||
this.handleResponse(t);
|
||||
break;
|
||||
case "notification":
|
||||
this.handleNotification(t);
|
||||
break;
|
||||
case "pong":
|
||||
break;
|
||||
default:
|
||||
this.log("Unknown message type:", t.type);
|
||||
}
|
||||
} catch (t) {
|
||||
this.log("Error parsing message:", t);
|
||||
}
|
||||
}
|
||||
handleResponse(e) {
|
||||
const t = this.messageHandlers.get(e.id);
|
||||
t && (t(e), this.messageHandlers.delete(e.id));
|
||||
}
|
||||
handleNotification(e) {
|
||||
const t = this.subscriptions.get(e.subscription_id);
|
||||
t?.callback && t.callback(e);
|
||||
}
|
||||
send(e) {
|
||||
if (!this.ws || this.ws.readyState !== WebSocket.OPEN)
|
||||
throw new Error("WebSocket is not connected");
|
||||
const t = JSON.stringify(e);
|
||||
this.log("Sending message:", e), this.ws.send(t);
|
||||
}
|
||||
startHeartbeat() {
|
||||
this.heartbeatTimer || (this.heartbeatTimer = setInterval(() => {
|
||||
if (this.isConnected()) {
|
||||
const e = {
|
||||
id: l(),
|
||||
type: "ping"
|
||||
};
|
||||
this.send(e);
|
||||
}
|
||||
}, this.config.heartbeatInterval));
|
||||
}
|
||||
stopHeartbeat() {
|
||||
this.heartbeatTimer && (clearInterval(this.heartbeatTimer), this.heartbeatTimer = null);
|
||||
}
|
||||
setState(e) {
|
||||
this.state !== e && (this.state = e, this.emit("stateChange", e));
|
||||
}
|
||||
ensureConnected() {
|
||||
if (!this.isConnected())
|
||||
throw new Error("WebSocket is not connected. Call connect() first.");
|
||||
}
|
||||
emit(e, ...t) {
|
||||
const s = this.eventListeners[e];
|
||||
s && s(...t);
|
||||
}
|
||||
log(...e) {
|
||||
this.config.debug && console.log("[WebSocketClient]", ...e);
|
||||
}
|
||||
}
|
||||
function v(n) {
|
||||
return typeof btoa == "function" ? "ZIP_" + btoa(n) : "ZIP_" + Buffer.from(n, "utf-8").toString("base64");
|
||||
}
|
||||
function w(n) {
|
||||
let e = n;
|
||||
return e.startsWith("ZIP_") ? (e = e.slice(4).replace(/[\n\r ]/g, ""), e = m(e)) : e.startsWith("__") && (e = e.slice(2).replace(/[\n\r ]/g, ""), e = m(e)), (e.startsWith("ZIP_") || e.startsWith("__")) && (e = w(e)), e;
|
||||
}
|
||||
function m(n) {
|
||||
return typeof atob == "function" ? atob(n) : Buffer.from(n, "base64").toString("utf-8");
|
||||
}
|
||||
function u(n) {
|
||||
const e = {};
|
||||
if (n.columns?.length && (e["X-Select-Fields"] = n.columns.join(",")), n.omit_columns?.length && (e["X-Not-Select-Fields"] = n.omit_columns.join(",")), n.filters?.length)
|
||||
for (const t of n.filters) {
|
||||
const s = t.logic_operator ?? "AND", r = y(t.operator), i = S(t);
|
||||
t.operator === "eq" && s === "AND" ? e[`X-FieldFilter-${t.column}`] = i : s === "OR" ? e[`X-SearchOr-${r}-${t.column}`] = i : e[`X-SearchOp-${r}-${t.column}`] = i;
|
||||
}
|
||||
if (n.sort?.length) {
|
||||
const t = n.sort.map((s) => s.direction.toUpperCase() === "DESC" ? `-${s.column}` : `+${s.column}`);
|
||||
e["X-Sort"] = t.join(",");
|
||||
}
|
||||
if (n.limit !== void 0 && (e["X-Limit"] = String(n.limit)), n.offset !== void 0 && (e["X-Offset"] = String(n.offset)), n.cursor_forward && (e["X-Cursor-Forward"] = n.cursor_forward), n.cursor_backward && (e["X-Cursor-Backward"] = n.cursor_backward), n.preload?.length) {
|
||||
const t = n.preload.map((s) => s.columns?.length ? `${s.relation}:${s.columns.join(",")}` : s.relation);
|
||||
e["X-Preload"] = t.join("|");
|
||||
}
|
||||
if (n.fetch_row_number && (e["X-Fetch-RowNumber"] = n.fetch_row_number), n.computedColumns?.length)
|
||||
for (const t of n.computedColumns)
|
||||
e[`X-CQL-SEL-${t.name}`] = t.expression;
|
||||
if (n.customOperators?.length) {
|
||||
const t = n.customOperators.map(
|
||||
(s) => s.sql
|
||||
);
|
||||
e["X-Custom-SQL-W"] = t.join(" AND ");
|
||||
}
|
||||
return e;
|
||||
}
|
||||
function y(n) {
|
||||
switch (n) {
|
||||
case "eq":
|
||||
return "equals";
|
||||
case "neq":
|
||||
return "notequals";
|
||||
case "gt":
|
||||
return "greaterthan";
|
||||
case "gte":
|
||||
return "greaterthanorequal";
|
||||
case "lt":
|
||||
return "lessthan";
|
||||
case "lte":
|
||||
return "lessthanorequal";
|
||||
case "like":
|
||||
case "ilike":
|
||||
case "contains":
|
||||
return "contains";
|
||||
case "startswith":
|
||||
return "beginswith";
|
||||
case "endswith":
|
||||
return "endswith";
|
||||
case "in":
|
||||
return "in";
|
||||
case "between":
|
||||
return "between";
|
||||
case "between_inclusive":
|
||||
return "betweeninclusive";
|
||||
case "is_null":
|
||||
return "empty";
|
||||
case "is_not_null":
|
||||
return "notempty";
|
||||
default:
|
||||
return n;
|
||||
}
|
||||
}
|
||||
function S(n) {
|
||||
return n.value === null || n.value === void 0 ? "" : Array.isArray(n.value) ? n.value.join(",") : String(n.value);
|
||||
}
|
||||
const b = /* @__PURE__ */ new Map();
|
||||
function C(n) {
|
||||
const e = n.baseUrl;
|
||||
let t = b.get(e);
|
||||
return t || (t = new H(n), b.set(e, t)), t;
|
||||
}
|
||||
class H {
|
||||
constructor(e) {
|
||||
this.config = e;
|
||||
}
|
||||
buildUrl(e, t, s) {
|
||||
let r = `${this.config.baseUrl}/${e}/${t}`;
|
||||
return s && (r += `/${s}`), r;
|
||||
}
|
||||
baseHeaders() {
|
||||
const e = {
|
||||
"Content-Type": "application/json"
|
||||
};
|
||||
return this.config.token && (e.Authorization = `Bearer ${this.config.token}`), e;
|
||||
}
|
||||
async fetchWithError(e, t) {
|
||||
const s = await fetch(e, t), r = await s.json();
|
||||
if (!s.ok)
|
||||
throw new Error(
|
||||
r.error?.message || `${s.statusText} (${s.status})`
|
||||
);
|
||||
return {
|
||||
data: r,
|
||||
success: !0,
|
||||
error: r.error ? r.error : void 0,
|
||||
metadata: {
|
||||
count: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
|
||||
total: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
|
||||
filtered: s.headers.get("content-range") ? Number(s.headers.get("content-range")?.split("/")[1]) : 0,
|
||||
offset: s.headers.get("content-range") ? Number(
|
||||
s.headers.get("content-range")?.split("/")[0].split("-")[0]
|
||||
) : 0,
|
||||
limit: s.headers.get("x-limit") ? Number(s.headers.get("x-limit")) : 0
|
||||
}
|
||||
};
|
||||
}
|
||||
async read(e, t, s, r) {
|
||||
const i = this.buildUrl(e, t, s), a = r ? u(r) : {};
|
||||
return this.fetchWithError(i, {
|
||||
method: "GET",
|
||||
headers: { ...this.baseHeaders(), ...a }
|
||||
});
|
||||
}
|
||||
async create(e, t, s, r) {
|
||||
const i = this.buildUrl(e, t), a = r ? u(r) : {};
|
||||
return this.fetchWithError(i, {
|
||||
method: "POST",
|
||||
headers: { ...this.baseHeaders(), ...a },
|
||||
body: JSON.stringify(s)
|
||||
});
|
||||
}
|
||||
async update(e, t, s, r, i) {
|
||||
const a = this.buildUrl(e, t, s), c = i ? u(i) : {};
|
||||
return this.fetchWithError(a, {
|
||||
method: "PUT",
|
||||
headers: { ...this.baseHeaders(), ...c },
|
||||
body: JSON.stringify(r)
|
||||
});
|
||||
}
|
||||
async delete(e, t, s) {
|
||||
const r = this.buildUrl(e, t, s);
|
||||
return this.fetchWithError(r, {
|
||||
method: "DELETE",
|
||||
headers: this.baseHeaders()
|
||||
});
|
||||
}
|
||||
}
|
||||
export {
|
||||
H as HeaderSpecClient,
|
||||
g as ResolveSpecClient,
|
||||
p as WebSocketClient,
|
||||
u as buildHeaders,
|
||||
w as decodeHeaderValue,
|
||||
v as encodeHeaderValue,
|
||||
C as getHeaderSpecClient,
|
||||
E as getResolveSpecClient,
|
||||
_ as getWebSocketClient
|
||||
};
|
||||
@@ -1,20 +1,23 @@
|
||||
{
|
||||
"name": "@warkypublic/resolvespec-js",
|
||||
"version": "1.0.0",
|
||||
"description": "Client side library for the ResolveSpec API",
|
||||
"version": "1.0.1",
|
||||
"description": "TypeScript client library for ResolveSpec REST, HeaderSpec, and WebSocket APIs",
|
||||
"type": "module",
|
||||
"main": "./src/index.ts",
|
||||
"module": "./src/index.ts",
|
||||
"types": "./src/index.ts",
|
||||
"main": "./dist/index.cjs",
|
||||
"module": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js",
|
||||
"require": "./dist/index.cjs"
|
||||
}
|
||||
},
|
||||
"publishConfig": {
|
||||
"access": "public",
|
||||
"main": "./dist/index.js",
|
||||
"module": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts"
|
||||
"access": "public"
|
||||
},
|
||||
"files": [
|
||||
"dist",
|
||||
"bin",
|
||||
"README.md"
|
||||
],
|
||||
"scripts": {
|
||||
@@ -25,38 +28,33 @@
|
||||
"lint": "eslint src"
|
||||
},
|
||||
"keywords": [
|
||||
"string",
|
||||
"blob",
|
||||
"dependencies",
|
||||
"workspace",
|
||||
"package",
|
||||
"cli",
|
||||
"tools",
|
||||
"npm",
|
||||
"yarn",
|
||||
"pnpm"
|
||||
"resolvespec",
|
||||
"headerspec",
|
||||
"websocket",
|
||||
"rest-client",
|
||||
"typescript",
|
||||
"api-client"
|
||||
],
|
||||
"author": "Hein (Warkanum) Puth",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"semver": "^7.6.3",
|
||||
"uuid": "^11.0.3"
|
||||
"uuid": "^13.0.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@changesets/cli": "^2.27.10",
|
||||
"@eslint/js": "^9.16.0",
|
||||
"@types/jsdom": "^21.1.7",
|
||||
"eslint": "^9.16.0",
|
||||
"globals": "^15.13.0",
|
||||
"jsdom": "^25.0.1",
|
||||
"typescript": "^5.7.2",
|
||||
"typescript-eslint": "^8.17.0",
|
||||
"vite": "^6.0.2",
|
||||
"vite-plugin-dts": "^4.3.0",
|
||||
"vitest": "^2.1.8"
|
||||
"@changesets/cli": "^2.29.8",
|
||||
"@eslint/js": "^10.0.1",
|
||||
"@types/jsdom": "^27.0.0",
|
||||
"eslint": "^10.0.0",
|
||||
"globals": "^17.3.0",
|
||||
"jsdom": "^28.1.0",
|
||||
"typescript": "^5.9.3",
|
||||
"typescript-eslint": "^8.55.0",
|
||||
"vite": "^7.3.1",
|
||||
"vite-plugin-dts": "^4.5.4",
|
||||
"vitest": "^4.0.18"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=14.16"
|
||||
"node": ">=18"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
|
||||
3376
resolvespec-js/pnpm-lock.yaml
generated
Normal file
3376
resolvespec-js/pnpm-lock.yaml
generated
Normal file
File diff suppressed because it is too large
Load Diff
143
resolvespec-js/src/__tests__/common.test.ts
Normal file
143
resolvespec-js/src/__tests__/common.test.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import type {
|
||||
Options,
|
||||
FilterOption,
|
||||
SortOption,
|
||||
PreloadOption,
|
||||
RequestBody,
|
||||
APIResponse,
|
||||
Metadata,
|
||||
APIError,
|
||||
Parameter,
|
||||
ComputedColumn,
|
||||
CustomOperator,
|
||||
} from '../common/types';
|
||||
|
||||
describe('Common Types', () => {
|
||||
it('should construct a valid FilterOption with logic_operator', () => {
|
||||
const filter: FilterOption = {
|
||||
column: 'name',
|
||||
operator: 'eq',
|
||||
value: 'test',
|
||||
logic_operator: 'OR',
|
||||
};
|
||||
expect(filter.logic_operator).toBe('OR');
|
||||
expect(filter.operator).toBe('eq');
|
||||
});
|
||||
|
||||
it('should construct Options with all new fields', () => {
|
||||
const opts: Options = {
|
||||
columns: ['id', 'name'],
|
||||
omit_columns: ['secret'],
|
||||
filters: [{ column: 'age', operator: 'gte', value: 18 }],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
cursor_forward: 'abc123',
|
||||
cursor_backward: 'xyz789',
|
||||
fetch_row_number: '42',
|
||||
parameters: [{ name: 'param1', value: 'val1', sequence: 1 }],
|
||||
computedColumns: [{ name: 'full_name', expression: "first || ' ' || last" }],
|
||||
customOperators: [{ name: 'custom', sql: "status = 'active'" }],
|
||||
preload: [{
|
||||
relation: 'Items',
|
||||
columns: ['id', 'title'],
|
||||
omit_columns: ['internal'],
|
||||
sort: [{ column: 'id', direction: 'ASC' }],
|
||||
recursive: true,
|
||||
primary_key: 'id',
|
||||
related_key: 'parent_id',
|
||||
sql_joins: ['LEFT JOIN other ON other.id = items.other_id'],
|
||||
join_aliases: ['other'],
|
||||
}],
|
||||
};
|
||||
expect(opts.omit_columns).toEqual(['secret']);
|
||||
expect(opts.cursor_forward).toBe('abc123');
|
||||
expect(opts.fetch_row_number).toBe('42');
|
||||
expect(opts.parameters![0].sequence).toBe(1);
|
||||
expect(opts.preload![0].recursive).toBe(true);
|
||||
});
|
||||
|
||||
it('should construct a RequestBody with numeric id', () => {
|
||||
const body: RequestBody = {
|
||||
operation: 'read',
|
||||
id: 42,
|
||||
options: { limit: 10 },
|
||||
};
|
||||
expect(body.id).toBe(42);
|
||||
});
|
||||
|
||||
it('should construct a RequestBody with string array id', () => {
|
||||
const body: RequestBody = {
|
||||
operation: 'delete',
|
||||
id: ['1', '2', '3'],
|
||||
};
|
||||
expect(Array.isArray(body.id)).toBe(true);
|
||||
});
|
||||
|
||||
it('should construct Metadata with count and row_number', () => {
|
||||
const meta: Metadata = {
|
||||
total: 100,
|
||||
count: 10,
|
||||
filtered: 50,
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
row_number: 5,
|
||||
};
|
||||
expect(meta.count).toBe(10);
|
||||
expect(meta.row_number).toBe(5);
|
||||
});
|
||||
|
||||
it('should construct APIError with detail field', () => {
|
||||
const err: APIError = {
|
||||
code: 'not_found',
|
||||
message: 'Record not found',
|
||||
detail: 'The record with id 42 does not exist',
|
||||
};
|
||||
expect(err.detail).toBeDefined();
|
||||
});
|
||||
|
||||
it('should construct APIResponse with metadata', () => {
|
||||
const resp: APIResponse<string[]> = {
|
||||
success: true,
|
||||
data: ['a', 'b'],
|
||||
metadata: { total: 2, count: 2, filtered: 2, limit: 10, offset: 0 },
|
||||
};
|
||||
expect(resp.metadata?.count).toBe(2);
|
||||
});
|
||||
|
||||
it('should support all operator types', () => {
|
||||
const operators: FilterOption['operator'][] = [
|
||||
'eq', 'neq', 'gt', 'gte', 'lt', 'lte',
|
||||
'like', 'ilike', 'in',
|
||||
'contains', 'startswith', 'endswith',
|
||||
'between', 'between_inclusive',
|
||||
'is_null', 'is_not_null',
|
||||
];
|
||||
for (const op of operators) {
|
||||
const f: FilterOption = { column: 'x', operator: op, value: 'v' };
|
||||
expect(f.operator).toBe(op);
|
||||
}
|
||||
});
|
||||
|
||||
it('should support PreloadOption with computed_ql and where', () => {
|
||||
const preload: PreloadOption = {
|
||||
relation: 'Details',
|
||||
where: "status = 'active'",
|
||||
computed_ql: { cql1: 'SUM(amount)' },
|
||||
table_name: 'detail_table',
|
||||
updatable: true,
|
||||
foreign_key: 'detail_id',
|
||||
recursive_child_key: 'parent_detail_id',
|
||||
};
|
||||
expect(preload.computed_ql?.cql1).toBe('SUM(amount)');
|
||||
expect(preload.updatable).toBe(true);
|
||||
});
|
||||
|
||||
it('should support Parameter interface', () => {
|
||||
const p: Parameter = { name: 'key', value: 'val' };
|
||||
expect(p.name).toBe('key');
|
||||
const p2: Parameter = { name: 'key2', value: 'val2', sequence: 5 };
|
||||
expect(p2.sequence).toBe(5);
|
||||
});
|
||||
});
|
||||
239
resolvespec-js/src/__tests__/headerspec.test.ts
Normal file
239
resolvespec-js/src/__tests__/headerspec.test.ts
Normal file
@@ -0,0 +1,239 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { buildHeaders, encodeHeaderValue, decodeHeaderValue, HeaderSpecClient, getHeaderSpecClient } from '../headerspec/client';
|
||||
import type { Options, ClientConfig, APIResponse } from '../common/types';
|
||||
|
||||
describe('buildHeaders', () => {
|
||||
it('should set X-Select-Fields for columns', () => {
|
||||
const h = buildHeaders({ columns: ['id', 'name', 'email'] });
|
||||
expect(h['X-Select-Fields']).toBe('id,name,email');
|
||||
});
|
||||
|
||||
it('should set X-Not-Select-Fields for omit_columns', () => {
|
||||
const h = buildHeaders({ omit_columns: ['secret', 'internal'] });
|
||||
expect(h['X-Not-Select-Fields']).toBe('secret,internal');
|
||||
});
|
||||
|
||||
it('should set X-FieldFilter for eq AND filters', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }],
|
||||
});
|
||||
expect(h['X-FieldFilter-status']).toBe('active');
|
||||
});
|
||||
|
||||
it('should set X-SearchOp for non-eq AND filters', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'age', operator: 'gte', value: 18 }],
|
||||
});
|
||||
expect(h['X-SearchOp-greaterthanorequal-age']).toBe('18');
|
||||
});
|
||||
|
||||
it('should set X-SearchOr for OR filters', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'name', operator: 'contains', value: 'test', logic_operator: 'OR' }],
|
||||
});
|
||||
expect(h['X-SearchOr-contains-name']).toBe('test');
|
||||
});
|
||||
|
||||
it('should set X-Sort with direction prefixes', () => {
|
||||
const h = buildHeaders({
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' },
|
||||
{ column: 'created_at', direction: 'DESC' },
|
||||
],
|
||||
});
|
||||
expect(h['X-Sort']).toBe('+name,-created_at');
|
||||
});
|
||||
|
||||
it('should set X-Limit and X-Offset', () => {
|
||||
const h = buildHeaders({ limit: 25, offset: 50 });
|
||||
expect(h['X-Limit']).toBe('25');
|
||||
expect(h['X-Offset']).toBe('50');
|
||||
});
|
||||
|
||||
it('should set cursor pagination headers', () => {
|
||||
const h = buildHeaders({ cursor_forward: 'abc', cursor_backward: 'xyz' });
|
||||
expect(h['X-Cursor-Forward']).toBe('abc');
|
||||
expect(h['X-Cursor-Backward']).toBe('xyz');
|
||||
});
|
||||
|
||||
it('should set X-Preload with pipe-separated relations', () => {
|
||||
const h = buildHeaders({
|
||||
preload: [
|
||||
{ relation: 'Items', columns: ['id', 'name'] },
|
||||
{ relation: 'Category' },
|
||||
],
|
||||
});
|
||||
expect(h['X-Preload']).toBe('Items:id,name|Category');
|
||||
});
|
||||
|
||||
it('should set X-Fetch-RowNumber', () => {
|
||||
const h = buildHeaders({ fetch_row_number: '42' });
|
||||
expect(h['X-Fetch-RowNumber']).toBe('42');
|
||||
});
|
||||
|
||||
it('should set X-CQL-SEL for computed columns', () => {
|
||||
const h = buildHeaders({
|
||||
computedColumns: [
|
||||
{ name: 'total', expression: 'price * qty' },
|
||||
],
|
||||
});
|
||||
expect(h['X-CQL-SEL-total']).toBe('price * qty');
|
||||
});
|
||||
|
||||
it('should set X-Custom-SQL-W for custom operators', () => {
|
||||
const h = buildHeaders({
|
||||
customOperators: [
|
||||
{ name: 'active', sql: "status = 'active'" },
|
||||
{ name: 'verified', sql: "verified = true" },
|
||||
],
|
||||
});
|
||||
expect(h['X-Custom-SQL-W']).toBe("status = 'active' AND verified = true");
|
||||
});
|
||||
|
||||
it('should return empty object for empty options', () => {
|
||||
const h = buildHeaders({});
|
||||
expect(Object.keys(h)).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle between filter with array value', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'price', operator: 'between', value: [10, 100] }],
|
||||
});
|
||||
expect(h['X-SearchOp-between-price']).toBe('10,100');
|
||||
});
|
||||
|
||||
it('should handle is_null filter with null value', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'deleted_at', operator: 'is_null', value: null }],
|
||||
});
|
||||
expect(h['X-SearchOp-empty-deleted_at']).toBe('');
|
||||
});
|
||||
|
||||
it('should handle in filter with array value', () => {
|
||||
const h = buildHeaders({
|
||||
filters: [{ column: 'id', operator: 'in', value: [1, 2, 3] }],
|
||||
});
|
||||
expect(h['X-SearchOp-in-id']).toBe('1,2,3');
|
||||
});
|
||||
});
|
||||
|
||||
describe('encodeHeaderValue / decodeHeaderValue', () => {
|
||||
it('should round-trip encode/decode', () => {
|
||||
const original = 'some complex value with spaces & symbols!';
|
||||
const encoded = encodeHeaderValue(original);
|
||||
expect(encoded.startsWith('ZIP_')).toBe(true);
|
||||
const decoded = decodeHeaderValue(encoded);
|
||||
expect(decoded).toBe(original);
|
||||
});
|
||||
|
||||
it('should decode __ prefixed values', () => {
|
||||
const encoded = '__' + btoa('hello');
|
||||
expect(decodeHeaderValue(encoded)).toBe('hello');
|
||||
});
|
||||
|
||||
it('should return plain values as-is', () => {
|
||||
expect(decodeHeaderValue('plain')).toBe('plain');
|
||||
});
|
||||
});
|
||||
|
||||
describe('HeaderSpecClient', () => {
|
||||
const config: ClientConfig = { baseUrl: 'http://localhost:3000', token: 'tok' };
|
||||
|
||||
function mockFetch<T>(data: APIResponse<T>, ok = true) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok,
|
||||
json: () => Promise.resolve(data),
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('read() sends GET with headers from options', async () => {
|
||||
globalThis.fetch = mockFetch({ success: true, data: [{ id: 1 }] });
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name'],
|
||||
limit: 10,
|
||||
});
|
||||
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users');
|
||||
expect(opts.method).toBe('GET');
|
||||
expect(opts.headers['X-Select-Fields']).toBe('id,name');
|
||||
expect(opts.headers['X-Limit']).toBe('10');
|
||||
expect(opts.headers['Authorization']).toBe('Bearer tok');
|
||||
});
|
||||
|
||||
it('read() with id appends to URL', async () => {
|
||||
globalThis.fetch = mockFetch({ success: true, data: {} });
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await client.read('public', 'users', '42');
|
||||
|
||||
const [url] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/42');
|
||||
});
|
||||
|
||||
it('create() sends POST with body and headers', async () => {
|
||||
globalThis.fetch = mockFetch({ success: true, data: { id: 1 } });
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await client.create('public', 'users', { name: 'Test' });
|
||||
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(opts.method).toBe('POST');
|
||||
expect(JSON.parse(opts.body)).toEqual({ name: 'Test' });
|
||||
});
|
||||
|
||||
it('update() sends PUT with id in URL', async () => {
|
||||
globalThis.fetch = mockFetch({ success: true, data: {} });
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await client.update('public', 'users', '1', { name: 'Updated' }, {
|
||||
filters: [{ column: 'active', operator: 'eq', value: true }],
|
||||
});
|
||||
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/1');
|
||||
expect(opts.method).toBe('PUT');
|
||||
expect(opts.headers['X-FieldFilter-active']).toBe('true');
|
||||
});
|
||||
|
||||
it('delete() sends DELETE', async () => {
|
||||
globalThis.fetch = mockFetch({ success: true, data: undefined as any });
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await client.delete('public', 'users', '1');
|
||||
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/1');
|
||||
expect(opts.method).toBe('DELETE');
|
||||
});
|
||||
|
||||
it('throws on non-ok response', async () => {
|
||||
globalThis.fetch = mockFetch(
|
||||
{ success: false, data: null as any, error: { code: 'err', message: 'fail' } },
|
||||
false
|
||||
);
|
||||
const client = new HeaderSpecClient(config);
|
||||
|
||||
await expect(client.read('public', 'users')).rejects.toThrow('fail');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getHeaderSpecClient singleton', () => {
|
||||
it('returns same instance for same baseUrl', () => {
|
||||
const a = getHeaderSpecClient({ baseUrl: 'http://hs-singleton:3000' });
|
||||
const b = getHeaderSpecClient({ baseUrl: 'http://hs-singleton:3000' });
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
|
||||
it('returns different instances for different baseUrls', () => {
|
||||
const a = getHeaderSpecClient({ baseUrl: 'http://hs-singleton-a:3000' });
|
||||
const b = getHeaderSpecClient({ baseUrl: 'http://hs-singleton-b:3000' });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
});
|
||||
178
resolvespec-js/src/__tests__/resolvespec.test.ts
Normal file
178
resolvespec-js/src/__tests__/resolvespec.test.ts
Normal file
@@ -0,0 +1,178 @@
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ResolveSpecClient, getResolveSpecClient } from '../resolvespec/client';
|
||||
import type { ClientConfig, APIResponse } from '../common/types';
|
||||
|
||||
const config: ClientConfig = { baseUrl: 'http://localhost:3000', token: 'test-token' };
|
||||
|
||||
function mockFetchResponse<T>(data: APIResponse<T>, ok = true, status = 200) {
|
||||
return vi.fn().mockResolvedValue({
|
||||
ok,
|
||||
status,
|
||||
json: () => Promise.resolve(data),
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('ResolveSpecClient', () => {
|
||||
it('read() sends POST with operation read', async () => {
|
||||
const response: APIResponse = { success: true, data: [{ id: 1 }] };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
const result = await client.read('public', 'users', 1);
|
||||
expect(result.success).toBe(true);
|
||||
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/1');
|
||||
expect(opts.method).toBe('POST');
|
||||
expect(opts.headers['Authorization']).toBe('Bearer test-token');
|
||||
|
||||
const body = JSON.parse(opts.body);
|
||||
expect(body.operation).toBe('read');
|
||||
});
|
||||
|
||||
it('read() with string array id puts id in body', async () => {
|
||||
const response: APIResponse = { success: true, data: [] };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await client.read('public', 'users', ['1', '2']);
|
||||
const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
|
||||
expect(body.id).toEqual(['1', '2']);
|
||||
});
|
||||
|
||||
it('read() passes options through', async () => {
|
||||
const response: APIResponse = { success: true, data: [] };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await client.read('public', 'users', undefined, {
|
||||
columns: ['id', 'name'],
|
||||
omit_columns: ['secret'],
|
||||
filters: [{ column: 'active', operator: 'eq', value: true }],
|
||||
sort: [{ column: 'name', direction: 'asc' }],
|
||||
limit: 10,
|
||||
offset: 0,
|
||||
cursor_forward: 'cursor1',
|
||||
fetch_row_number: '5',
|
||||
});
|
||||
|
||||
const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
|
||||
expect(body.options.columns).toEqual(['id', 'name']);
|
||||
expect(body.options.omit_columns).toEqual(['secret']);
|
||||
expect(body.options.cursor_forward).toBe('cursor1');
|
||||
expect(body.options.fetch_row_number).toBe('5');
|
||||
});
|
||||
|
||||
it('create() sends POST with operation create and data', async () => {
|
||||
const response: APIResponse = { success: true, data: { id: 1, name: 'Test' } };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
const result = await client.create('public', 'users', { name: 'Test' });
|
||||
expect(result.data.name).toBe('Test');
|
||||
|
||||
const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
|
||||
expect(body.operation).toBe('create');
|
||||
expect(body.data.name).toBe('Test');
|
||||
});
|
||||
|
||||
it('update() with single id puts id in URL', async () => {
|
||||
const response: APIResponse = { success: true, data: { id: 1 } };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await client.update('public', 'users', { name: 'Updated' }, 1);
|
||||
const [url] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/1');
|
||||
});
|
||||
|
||||
it('update() with string array id puts id in body', async () => {
|
||||
const response: APIResponse = { success: true, data: {} };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await client.update('public', 'users', { active: false }, ['1', '2']);
|
||||
const body = JSON.parse((globalThis.fetch as any).mock.calls[0][1].body);
|
||||
expect(body.id).toEqual(['1', '2']);
|
||||
});
|
||||
|
||||
it('delete() sends POST with operation delete', async () => {
|
||||
const response: APIResponse<void> = { success: true, data: undefined as any };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await client.delete('public', 'users', 1);
|
||||
const [url, opts] = (globalThis.fetch as any).mock.calls[0];
|
||||
expect(url).toBe('http://localhost:3000/public/users/1');
|
||||
|
||||
const body = JSON.parse(opts.body);
|
||||
expect(body.operation).toBe('delete');
|
||||
});
|
||||
|
||||
it('getMetadata() sends GET request', async () => {
|
||||
const response: APIResponse = {
|
||||
success: true,
|
||||
data: { schema: 'public', table: 'users', columns: [], relations: [] },
|
||||
};
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
const result = await client.getMetadata('public', 'users');
|
||||
expect(result.data.table).toBe('users');
|
||||
|
||||
const opts = (globalThis.fetch as any).mock.calls[0][1];
|
||||
expect(opts.method).toBe('GET');
|
||||
});
|
||||
|
||||
it('throws on non-ok response', async () => {
|
||||
const errorResp = {
|
||||
success: false,
|
||||
data: null,
|
||||
error: { code: 'not_found', message: 'Not found' },
|
||||
};
|
||||
globalThis.fetch = mockFetchResponse(errorResp as any, false, 404);
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await expect(client.read('public', 'users', 999)).rejects.toThrow('Not found');
|
||||
});
|
||||
|
||||
it('throws generic error when no error message', async () => {
|
||||
globalThis.fetch = vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: () => Promise.resolve({ success: false, data: null }),
|
||||
});
|
||||
|
||||
const client = new ResolveSpecClient(config);
|
||||
await expect(client.read('public', 'users')).rejects.toThrow('An error occurred');
|
||||
});
|
||||
|
||||
it('config without token omits Authorization header', async () => {
|
||||
const noAuthConfig: ClientConfig = { baseUrl: 'http://localhost:3000' };
|
||||
const response: APIResponse = { success: true, data: [] };
|
||||
globalThis.fetch = mockFetchResponse(response);
|
||||
|
||||
const client = new ResolveSpecClient(noAuthConfig);
|
||||
await client.read('public', 'users');
|
||||
const opts = (globalThis.fetch as any).mock.calls[0][1];
|
||||
expect(opts.headers['Authorization']).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getResolveSpecClient singleton', () => {
|
||||
it('returns same instance for same baseUrl', () => {
|
||||
const a = getResolveSpecClient({ baseUrl: 'http://singleton-test:3000' });
|
||||
const b = getResolveSpecClient({ baseUrl: 'http://singleton-test:3000' });
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
|
||||
it('returns different instances for different baseUrls', () => {
|
||||
const a = getResolveSpecClient({ baseUrl: 'http://singleton-a:3000' });
|
||||
const b = getResolveSpecClient({ baseUrl: 'http://singleton-b:3000' });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
});
|
||||
336
resolvespec-js/src/__tests__/websocketspec.test.ts
Normal file
336
resolvespec-js/src/__tests__/websocketspec.test.ts
Normal file
@@ -0,0 +1,336 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { WebSocketClient, getWebSocketClient } from '../websocketspec/client';
|
||||
import type { WebSocketClientConfig } from '../websocketspec/types';
|
||||
|
||||
// Mock uuid
|
||||
vi.mock('uuid', () => ({
|
||||
v4: vi.fn(() => 'mock-uuid-1234'),
|
||||
}));
|
||||
|
||||
// Mock WebSocket
|
||||
class MockWebSocket {
|
||||
static OPEN = 1;
|
||||
static CLOSED = 3;
|
||||
|
||||
url: string;
|
||||
readyState = MockWebSocket.OPEN;
|
||||
onopen: ((ev: any) => void) | null = null;
|
||||
onclose: ((ev: any) => void) | null = null;
|
||||
onmessage: ((ev: any) => void) | null = null;
|
||||
onerror: ((ev: any) => void) | null = null;
|
||||
|
||||
private sentMessages: string[] = [];
|
||||
|
||||
constructor(url: string) {
|
||||
this.url = url;
|
||||
// Simulate async open
|
||||
setTimeout(() => {
|
||||
this.onopen?.({});
|
||||
}, 0);
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sentMessages.push(data);
|
||||
}
|
||||
|
||||
close() {
|
||||
this.readyState = MockWebSocket.CLOSED;
|
||||
this.onclose?.({ code: 1000, reason: 'Normal closure' } as any);
|
||||
}
|
||||
|
||||
getSentMessages(): any[] {
|
||||
return this.sentMessages.map((m) => JSON.parse(m));
|
||||
}
|
||||
|
||||
simulateMessage(data: any) {
|
||||
this.onmessage?.({ data: JSON.stringify(data) });
|
||||
}
|
||||
}
|
||||
|
||||
let mockWsInstance: MockWebSocket | null = null;
|
||||
|
||||
beforeEach(() => {
|
||||
mockWsInstance = null;
|
||||
(globalThis as any).WebSocket = class extends MockWebSocket {
|
||||
constructor(url: string) {
|
||||
super(url);
|
||||
mockWsInstance = this;
|
||||
}
|
||||
};
|
||||
(globalThis as any).WebSocket.OPEN = MockWebSocket.OPEN;
|
||||
(globalThis as any).WebSocket.CLOSED = MockWebSocket.CLOSED;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('WebSocketClient', () => {
|
||||
const wsConfig: WebSocketClientConfig = {
|
||||
url: 'ws://localhost:8080',
|
||||
reconnect: false,
|
||||
heartbeatInterval: 60000,
|
||||
};
|
||||
|
||||
it('should connect and set state to connected', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
expect(client.getState()).toBe('connected');
|
||||
expect(client.isConnected()).toBe(true);
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should disconnect and set state to disconnected', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
client.disconnect();
|
||||
expect(client.getState()).toBe('disconnected');
|
||||
expect(client.isConnected()).toBe(false);
|
||||
});
|
||||
|
||||
it('should send read request', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const readPromise = client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [{ column: 'active', operator: 'eq', value: true }],
|
||||
limit: 10,
|
||||
});
|
||||
|
||||
// Simulate server response
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
expect(sent.length).toBe(1);
|
||||
expect(sent[0].operation).toBe('read');
|
||||
expect(sent[0].entity).toBe('users');
|
||||
expect(sent[0].options.filters[0].column).toBe('active');
|
||||
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
data: [{ id: 1 }],
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const result = await readPromise;
|
||||
expect(result).toEqual([{ id: 1 }]);
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should send create request', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const createPromise = client.create('users', { name: 'Test' }, { schema: 'public' });
|
||||
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
expect(sent[0].operation).toBe('create');
|
||||
expect(sent[0].data.name).toBe('Test');
|
||||
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
data: { id: 1, name: 'Test' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const result = await createPromise;
|
||||
expect(result.name).toBe('Test');
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should send update request with record_id', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const updatePromise = client.update('users', '1', { name: 'Updated' });
|
||||
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
expect(sent[0].operation).toBe('update');
|
||||
expect(sent[0].record_id).toBe('1');
|
||||
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
data: { id: 1, name: 'Updated' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
await updatePromise;
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should send delete request', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const deletePromise = client.delete('users', '1');
|
||||
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
expect(sent[0].operation).toBe('delete');
|
||||
expect(sent[0].record_id).toBe('1');
|
||||
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
await deletePromise;
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should reject on failed request', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const readPromise = client.read('users');
|
||||
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: false,
|
||||
error: { code: 'not_found', message: 'Not found' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
await expect(readPromise).rejects.toThrow('Not found');
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should handle subscriptions', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
const callback = vi.fn();
|
||||
const subPromise = client.subscribe('users', callback, {
|
||||
schema: 'public',
|
||||
});
|
||||
|
||||
const sent = mockWsInstance!.getSentMessages();
|
||||
expect(sent[0].type).toBe('subscription');
|
||||
expect(sent[0].operation).toBe('subscribe');
|
||||
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
data: { subscription_id: 'sub-1' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const subId = await subPromise;
|
||||
expect(subId).toBe('sub-1');
|
||||
expect(client.getSubscriptions()).toHaveLength(1);
|
||||
|
||||
// Simulate notification
|
||||
mockWsInstance!.simulateMessage({
|
||||
type: 'notification',
|
||||
operation: 'create',
|
||||
subscription_id: 'sub-1',
|
||||
entity: 'users',
|
||||
data: { id: 2, name: 'New' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
expect(callback).toHaveBeenCalledTimes(1);
|
||||
expect(callback.mock.calls[0][0].data.id).toBe(2);
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should handle unsubscribe', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
// Subscribe first
|
||||
const subPromise = client.subscribe('users', vi.fn());
|
||||
let sent = mockWsInstance!.getSentMessages();
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[0].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
data: { subscription_id: 'sub-1' },
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
await subPromise;
|
||||
|
||||
// Unsubscribe
|
||||
const unsubPromise = client.unsubscribe('sub-1');
|
||||
sent = mockWsInstance!.getSentMessages();
|
||||
mockWsInstance!.simulateMessage({
|
||||
id: sent[sent.length - 1].id,
|
||||
type: 'response',
|
||||
success: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
|
||||
await unsubPromise;
|
||||
expect(client.getSubscriptions()).toHaveLength(0);
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should emit events', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
const connectCb = vi.fn();
|
||||
const stateChangeCb = vi.fn();
|
||||
|
||||
client.on('connect', connectCb);
|
||||
client.on('stateChange', stateChangeCb);
|
||||
|
||||
await client.connect();
|
||||
|
||||
expect(connectCb).toHaveBeenCalledTimes(1);
|
||||
expect(stateChangeCb).toHaveBeenCalled();
|
||||
|
||||
client.off('connect');
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should reject when sending without connection', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await expect(client.read('users')).rejects.toThrow('WebSocket is not connected');
|
||||
});
|
||||
|
||||
it('should handle pong messages without error', async () => {
|
||||
const client = new WebSocketClient(wsConfig);
|
||||
await client.connect();
|
||||
|
||||
// Should not throw
|
||||
mockWsInstance!.simulateMessage({ type: 'pong' });
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
|
||||
it('should handle malformed messages gracefully', async () => {
|
||||
const client = new WebSocketClient({ ...wsConfig, debug: false });
|
||||
await client.connect();
|
||||
|
||||
// Simulate non-JSON message
|
||||
mockWsInstance!.onmessage?.({ data: 'not-json' } as any);
|
||||
|
||||
client.disconnect();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getWebSocketClient singleton', () => {
|
||||
it('returns same instance for same url', () => {
|
||||
const a = getWebSocketClient({ url: 'ws://ws-singleton:8080' });
|
||||
const b = getWebSocketClient({ url: 'ws://ws-singleton:8080' });
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
|
||||
it('returns different instances for different urls', () => {
|
||||
const a = getWebSocketClient({ url: 'ws://ws-singleton-a:8080' });
|
||||
const b = getWebSocketClient({ url: 'ws://ws-singleton-b:8080' });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
});
|
||||
@@ -1,132 +0,0 @@
|
||||
import { ClientConfig, APIResponse, TableMetadata, Options, RequestBody } from "./types";
|
||||
|
||||
// Helper functions
|
||||
const getHeaders = (options?: Record<string,any>): HeadersInit => {
|
||||
const headers: HeadersInit = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (options?.token) {
|
||||
headers['Authorization'] = `Bearer ${options.token}`;
|
||||
}
|
||||
|
||||
return headers;
|
||||
};
|
||||
|
||||
const buildUrl = (config: ClientConfig, schema: string, entity: string, id?: string): string => {
|
||||
let url = `${config.baseUrl}/${schema}/${entity}`;
|
||||
if (id) {
|
||||
url += `/${id}`;
|
||||
}
|
||||
return url;
|
||||
};
|
||||
|
||||
const fetchWithError = async <T>(url: string, options: RequestInit): Promise<APIResponse<T>> => {
|
||||
try {
|
||||
const response = await fetch(url, options);
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.error?.message || 'An error occurred');
|
||||
}
|
||||
|
||||
return data;
|
||||
} catch (error) {
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
// API Functions
|
||||
export const getMetadata = async (
|
||||
config: ClientConfig,
|
||||
schema: string,
|
||||
entity: string
|
||||
): Promise<APIResponse<TableMetadata>> => {
|
||||
const url = buildUrl(config, schema, entity);
|
||||
return fetchWithError<TableMetadata>(url, {
|
||||
method: 'GET',
|
||||
headers: getHeaders(config),
|
||||
});
|
||||
};
|
||||
|
||||
export const read = async <T = any>(
|
||||
config: ClientConfig,
|
||||
schema: string,
|
||||
entity: string,
|
||||
id?: string,
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> => {
|
||||
const url = buildUrl(config, schema, entity, id);
|
||||
const body: RequestBody = {
|
||||
operation: 'read',
|
||||
options,
|
||||
};
|
||||
|
||||
return fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(config),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
};
|
||||
|
||||
export const create = async <T = any>(
|
||||
config: ClientConfig,
|
||||
schema: string,
|
||||
entity: string,
|
||||
data: any | any[],
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> => {
|
||||
const url = buildUrl(config, schema, entity);
|
||||
const body: RequestBody = {
|
||||
operation: 'create',
|
||||
data,
|
||||
options,
|
||||
};
|
||||
|
||||
return fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(config),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
};
|
||||
|
||||
export const update = async <T = any>(
|
||||
config: ClientConfig,
|
||||
schema: string,
|
||||
entity: string,
|
||||
data: any | any[],
|
||||
id?: string | string[],
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> => {
|
||||
const url = buildUrl(config, schema, entity, typeof id === 'string' ? id : undefined);
|
||||
const body: RequestBody = {
|
||||
operation: 'update',
|
||||
id: typeof id === 'string' ? undefined : id,
|
||||
data,
|
||||
options,
|
||||
};
|
||||
|
||||
return fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(config),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
};
|
||||
|
||||
export const deleteEntity = async (
|
||||
config: ClientConfig,
|
||||
schema: string,
|
||||
entity: string,
|
||||
id: string
|
||||
): Promise<APIResponse<void>> => {
|
||||
const url = buildUrl(config, schema, entity, id);
|
||||
const body: RequestBody = {
|
||||
operation: 'delete',
|
||||
};
|
||||
|
||||
return fetchWithError<void>(url, {
|
||||
method: 'POST',
|
||||
headers: getHeaders(config),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
};
|
||||
1
resolvespec-js/src/common/index.ts
Normal file
1
resolvespec-js/src/common/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export * from './types';
|
||||
129
resolvespec-js/src/common/types.ts
Normal file
129
resolvespec-js/src/common/types.ts
Normal file
@@ -0,0 +1,129 @@
|
||||
// Types aligned with Go pkg/common/types.go
|
||||
|
||||
export type Operator =
|
||||
| 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte'
|
||||
| 'like' | 'ilike' | 'in'
|
||||
| 'contains' | 'startswith' | 'endswith'
|
||||
| 'between' | 'between_inclusive'
|
||||
| 'is_null' | 'is_not_null';
|
||||
|
||||
export type Operation = 'read' | 'create' | 'update' | 'delete';
|
||||
export type SortDirection = 'asc' | 'desc' | 'ASC' | 'DESC';
|
||||
|
||||
export interface Parameter {
|
||||
name: string;
|
||||
value: string;
|
||||
sequence?: number;
|
||||
}
|
||||
|
||||
export interface PreloadOption {
|
||||
relation: string;
|
||||
table_name?: string;
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
sort?: SortOption[];
|
||||
filters?: FilterOption[];
|
||||
where?: string;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
updatable?: boolean;
|
||||
computed_ql?: Record<string, string>;
|
||||
recursive?: boolean;
|
||||
// Relationship keys
|
||||
primary_key?: string;
|
||||
related_key?: string;
|
||||
foreign_key?: string;
|
||||
recursive_child_key?: string;
|
||||
// Custom SQL JOINs
|
||||
sql_joins?: string[];
|
||||
join_aliases?: string[];
|
||||
}
|
||||
|
||||
export interface FilterOption {
|
||||
column: string;
|
||||
operator: Operator | string;
|
||||
value: any;
|
||||
logic_operator?: 'AND' | 'OR';
|
||||
}
|
||||
|
||||
export interface SortOption {
|
||||
column: string;
|
||||
direction: SortDirection;
|
||||
}
|
||||
|
||||
export interface CustomOperator {
|
||||
name: string;
|
||||
sql: string;
|
||||
}
|
||||
|
||||
export interface ComputedColumn {
|
||||
name: string;
|
||||
expression: string;
|
||||
}
|
||||
|
||||
export interface Options {
|
||||
preload?: PreloadOption[];
|
||||
columns?: string[];
|
||||
omit_columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
customOperators?: CustomOperator[];
|
||||
computedColumns?: ComputedColumn[];
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
export interface RequestBody {
|
||||
operation: Operation;
|
||||
id?: number | string | string[];
|
||||
data?: any | any[];
|
||||
options?: Options;
|
||||
}
|
||||
|
||||
export interface Metadata {
|
||||
total: number;
|
||||
count: number;
|
||||
filtered: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
row_number?: number;
|
||||
}
|
||||
|
||||
export interface APIError {
|
||||
code: string;
|
||||
message: string;
|
||||
details?: any;
|
||||
detail?: string;
|
||||
}
|
||||
|
||||
export interface APIResponse<T = any> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
metadata?: Metadata;
|
||||
error?: APIError;
|
||||
}
|
||||
|
||||
export interface Column {
|
||||
name: string;
|
||||
type: string;
|
||||
is_nullable: boolean;
|
||||
is_primary: boolean;
|
||||
is_unique: boolean;
|
||||
has_index: boolean;
|
||||
}
|
||||
|
||||
export interface TableMetadata {
|
||||
schema: string;
|
||||
table: string;
|
||||
columns: Column[];
|
||||
relations: string[];
|
||||
}
|
||||
|
||||
export interface ClientConfig {
|
||||
baseUrl: string;
|
||||
token?: string;
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
import { getMetadata, read, create, update, deleteEntity } from "./api";
|
||||
import { ClientConfig } from "./types";
|
||||
|
||||
// Usage Examples
|
||||
const config: ClientConfig = {
|
||||
baseUrl: 'http://api.example.com/v1',
|
||||
token: 'your-token-here'
|
||||
};
|
||||
|
||||
// Example usage
|
||||
const examples = async () => {
|
||||
// Get metadata
|
||||
const metadata = await getMetadata(config, 'test', 'employees');
|
||||
|
||||
|
||||
// Read with relations
|
||||
const employees = await read(config, 'test', 'employees', undefined, {
|
||||
preload: [
|
||||
{
|
||||
relation: 'department',
|
||||
columns: ['id', 'name']
|
||||
}
|
||||
],
|
||||
filters: [
|
||||
{
|
||||
column: 'status',
|
||||
operator: 'eq',
|
||||
value: 'active'
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Create single record
|
||||
const newEmployee = await create(config, 'test', 'employees', {
|
||||
first_name: 'John',
|
||||
last_name: 'Doe',
|
||||
email: 'john@example.com'
|
||||
});
|
||||
|
||||
// Bulk create
|
||||
const newEmployees = await create(config, 'test', 'employees', [
|
||||
{
|
||||
first_name: 'Jane',
|
||||
last_name: 'Smith',
|
||||
email: 'jane@example.com'
|
||||
},
|
||||
{
|
||||
first_name: 'Bob',
|
||||
last_name: 'Johnson',
|
||||
email: 'bob@example.com'
|
||||
}
|
||||
]);
|
||||
|
||||
// Update single record
|
||||
const updatedEmployee = await update(config, 'test', 'employees',
|
||||
{ status: 'inactive' },
|
||||
'emp123'
|
||||
);
|
||||
|
||||
// Bulk update
|
||||
const updatedEmployees = await update(config, 'test', 'employees',
|
||||
{ department_id: 'dept2' },
|
||||
['emp1', 'emp2', 'emp3']
|
||||
);
|
||||
|
||||
// Delete
|
||||
await deleteEntity(config, 'test', 'employees', 'emp123');
|
||||
};
|
||||
345
resolvespec-js/src/headerspec/client.ts
Normal file
345
resolvespec-js/src/headerspec/client.ts
Normal file
@@ -0,0 +1,345 @@
|
||||
import type {
|
||||
APIResponse,
|
||||
ClientConfig,
|
||||
CustomOperator,
|
||||
FilterOption,
|
||||
Options,
|
||||
PreloadOption,
|
||||
SortOption,
|
||||
} from "../common/types";
|
||||
|
||||
/**
|
||||
* Encode a value with base64 and ZIP_ prefix for complex header values.
|
||||
*/
|
||||
export function encodeHeaderValue(value: string): string {
|
||||
if (typeof btoa === "function") {
|
||||
return "ZIP_" + btoa(value);
|
||||
}
|
||||
return "ZIP_" + Buffer.from(value, "utf-8").toString("base64");
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode a header value that may be base64 encoded with ZIP_ or __ prefix.
|
||||
*/
|
||||
export function decodeHeaderValue(value: string): string {
|
||||
let code = value;
|
||||
|
||||
if (code.startsWith("ZIP_")) {
|
||||
code = code.slice(4).replace(/[\n\r ]/g, "");
|
||||
code = decodeBase64(code);
|
||||
} else if (code.startsWith("__")) {
|
||||
code = code.slice(2).replace(/[\n\r ]/g, "");
|
||||
code = decodeBase64(code);
|
||||
}
|
||||
|
||||
// Handle nested encoding
|
||||
if (code.startsWith("ZIP_") || code.startsWith("__")) {
|
||||
code = decodeHeaderValue(code);
|
||||
}
|
||||
|
||||
return code;
|
||||
}
|
||||
|
||||
function decodeBase64(str: string): string {
|
||||
if (typeof atob === "function") {
|
||||
return atob(str);
|
||||
}
|
||||
return Buffer.from(str, "base64").toString("utf-8");
|
||||
}
|
||||
|
||||
/**
|
||||
* Build HTTP headers from Options, matching Go's restheadspec handler conventions.
|
||||
*
|
||||
* Header mapping:
|
||||
* - X-Select-Fields: comma-separated columns
|
||||
* - X-Not-Select-Fields: comma-separated omit_columns
|
||||
* - X-FieldFilter-{col}: exact match (eq)
|
||||
* - X-SearchOp-{operator}-{col}: AND filter
|
||||
* - X-SearchOr-{operator}-{col}: OR filter
|
||||
* - X-Sort: +col (asc), -col (desc)
|
||||
* - X-Limit, X-Offset: pagination
|
||||
* - X-Cursor-Forward, X-Cursor-Backward: cursor pagination
|
||||
* - X-Preload: RelationName:field1,field2 pipe-separated
|
||||
* - X-Fetch-RowNumber: row number fetch
|
||||
* - X-CQL-SEL-{col}: computed columns
|
||||
* - X-Custom-SQL-W: custom operators (AND)
|
||||
*/
|
||||
export function buildHeaders(options: Options): Record<string, string> {
|
||||
const headers: Record<string, string> = {};
|
||||
|
||||
// Column selection
|
||||
if (options.columns?.length) {
|
||||
headers["X-Select-Fields"] = options.columns.join(",");
|
||||
}
|
||||
|
||||
if (options.omit_columns?.length) {
|
||||
headers["X-Not-Select-Fields"] = options.omit_columns.join(",");
|
||||
}
|
||||
|
||||
// Filters
|
||||
if (options.filters?.length) {
|
||||
for (const filter of options.filters) {
|
||||
const logicOp = filter.logic_operator ?? "AND";
|
||||
const op = mapOperatorToHeaderOp(filter.operator);
|
||||
const valueStr = formatFilterValue(filter);
|
||||
|
||||
if (filter.operator === "eq" && logicOp === "AND") {
|
||||
// Simple field filter shorthand
|
||||
headers[`X-FieldFilter-${filter.column}`] = valueStr;
|
||||
} else if (logicOp === "OR") {
|
||||
headers[`X-SearchOr-${op}-${filter.column}`] = valueStr;
|
||||
} else {
|
||||
headers[`X-SearchOp-${op}-${filter.column}`] = valueStr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort
|
||||
if (options.sort?.length) {
|
||||
const sortParts = options.sort.map((s: SortOption) => {
|
||||
const dir = s.direction.toUpperCase();
|
||||
return dir === "DESC" ? `-${s.column}` : `+${s.column}`;
|
||||
});
|
||||
headers["X-Sort"] = sortParts.join(",");
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if (options.limit !== undefined) {
|
||||
headers["X-Limit"] = String(options.limit);
|
||||
}
|
||||
if (options.offset !== undefined) {
|
||||
headers["X-Offset"] = String(options.offset);
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if (options.cursor_forward) {
|
||||
headers["X-Cursor-Forward"] = options.cursor_forward;
|
||||
}
|
||||
if (options.cursor_backward) {
|
||||
headers["X-Cursor-Backward"] = options.cursor_backward;
|
||||
}
|
||||
|
||||
// Preload
|
||||
if (options.preload?.length) {
|
||||
const parts = options.preload.map((p: PreloadOption) => {
|
||||
if (p.columns?.length) {
|
||||
return `${p.relation}:${p.columns.join(",")}`;
|
||||
}
|
||||
return p.relation;
|
||||
});
|
||||
headers["X-Preload"] = parts.join("|");
|
||||
}
|
||||
|
||||
// Fetch row number
|
||||
if (options.fetch_row_number) {
|
||||
headers["X-Fetch-RowNumber"] = options.fetch_row_number;
|
||||
}
|
||||
|
||||
// Computed columns
|
||||
if (options.computedColumns?.length) {
|
||||
for (const cc of options.computedColumns) {
|
||||
headers[`X-CQL-SEL-${cc.name}`] = cc.expression;
|
||||
}
|
||||
}
|
||||
|
||||
// Custom operators -> X-Custom-SQL-W
|
||||
if (options.customOperators?.length) {
|
||||
const sqlParts = options.customOperators.map(
|
||||
(co: CustomOperator) => co.sql,
|
||||
);
|
||||
headers["X-Custom-SQL-W"] = sqlParts.join(" AND ");
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
function mapOperatorToHeaderOp(operator: string): string {
|
||||
switch (operator) {
|
||||
case "eq":
|
||||
return "equals";
|
||||
case "neq":
|
||||
return "notequals";
|
||||
case "gt":
|
||||
return "greaterthan";
|
||||
case "gte":
|
||||
return "greaterthanorequal";
|
||||
case "lt":
|
||||
return "lessthan";
|
||||
case "lte":
|
||||
return "lessthanorequal";
|
||||
case "like":
|
||||
case "ilike":
|
||||
case "contains":
|
||||
return "contains";
|
||||
case "startswith":
|
||||
return "beginswith";
|
||||
case "endswith":
|
||||
return "endswith";
|
||||
case "in":
|
||||
return "in";
|
||||
case "between":
|
||||
return "between";
|
||||
case "between_inclusive":
|
||||
return "betweeninclusive";
|
||||
case "is_null":
|
||||
return "empty";
|
||||
case "is_not_null":
|
||||
return "notempty";
|
||||
default:
|
||||
return operator;
|
||||
}
|
||||
}
|
||||
|
||||
function formatFilterValue(filter: FilterOption): string {
|
||||
if (filter.value === null || filter.value === undefined) {
|
||||
return "";
|
||||
}
|
||||
if (Array.isArray(filter.value)) {
|
||||
return filter.value.join(",");
|
||||
}
|
||||
return String(filter.value);
|
||||
}
|
||||
|
||||
const instances = new Map<string, HeaderSpecClient>();
|
||||
|
||||
export function getHeaderSpecClient(config: ClientConfig): HeaderSpecClient {
|
||||
const key = config.baseUrl;
|
||||
let instance = instances.get(key);
|
||||
if (!instance) {
|
||||
instance = new HeaderSpecClient(config);
|
||||
instances.set(key, instance);
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* HeaderSpec REST client.
|
||||
* Sends query options via HTTP headers instead of request body, matching the Go restheadspec handler.
|
||||
*
|
||||
* HTTP methods: GET=read, POST=create, PUT=update, DELETE=delete
|
||||
*/
|
||||
export class HeaderSpecClient {
|
||||
private config: ClientConfig;
|
||||
|
||||
constructor(config: ClientConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
private buildUrl(schema: string, entity: string, id?: string): string {
|
||||
let url = `${this.config.baseUrl}/${schema}/${entity}`;
|
||||
if (id) {
|
||||
url += `/${id}`;
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
private baseHeaders(): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
};
|
||||
if (this.config.token) {
|
||||
headers["Authorization"] = `Bearer ${this.config.token}`;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
private async fetchWithError<T>(
|
||||
url: string,
|
||||
init: RequestInit,
|
||||
): Promise<APIResponse<T>> {
|
||||
const response = await fetch(url, init);
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
data.error?.message ||
|
||||
`${response.statusText} ` + `(${response.status})`,
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
data: data,
|
||||
success: true,
|
||||
error: data.error ? data.error : undefined,
|
||||
metadata: {
|
||||
count: response.headers.get("content-range")
|
||||
? Number(response.headers.get("content-range")?.split("/")[1])
|
||||
: 0,
|
||||
total: response.headers.get("content-range")
|
||||
? Number(response.headers.get("content-range")?.split("/")[1])
|
||||
: 0,
|
||||
filtered: response.headers.get("content-range")
|
||||
? Number(response.headers.get("content-range")?.split("/")[1])
|
||||
: 0,
|
||||
offset: response.headers.get("content-range")
|
||||
? Number(
|
||||
response.headers
|
||||
.get("content-range")
|
||||
?.split("/")[0]
|
||||
.split("-")[0],
|
||||
)
|
||||
: 0,
|
||||
limit: response.headers.get("x-limit")
|
||||
? Number(response.headers.get("x-limit"))
|
||||
: 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
async read<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
id?: string,
|
||||
options?: Options,
|
||||
): Promise<APIResponse<T>> {
|
||||
const url = this.buildUrl(schema, entity, id);
|
||||
const optHeaders = options ? buildHeaders(options) : {};
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: "GET",
|
||||
headers: { ...this.baseHeaders(), ...optHeaders },
|
||||
});
|
||||
}
|
||||
|
||||
async create<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
data: any,
|
||||
options?: Options,
|
||||
): Promise<APIResponse<T>> {
|
||||
const url = this.buildUrl(schema, entity);
|
||||
const optHeaders = options ? buildHeaders(options) : {};
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: "POST",
|
||||
headers: { ...this.baseHeaders(), ...optHeaders },
|
||||
body: JSON.stringify(data),
|
||||
});
|
||||
}
|
||||
|
||||
async update<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
id: string,
|
||||
data: any,
|
||||
options?: Options,
|
||||
): Promise<APIResponse<T>> {
|
||||
const url = this.buildUrl(schema, entity, id);
|
||||
const optHeaders = options ? buildHeaders(options) : {};
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: "PUT",
|
||||
headers: { ...this.baseHeaders(), ...optHeaders },
|
||||
body: JSON.stringify(data),
|
||||
});
|
||||
}
|
||||
|
||||
async delete(
|
||||
schema: string,
|
||||
entity: string,
|
||||
id: string,
|
||||
): Promise<APIResponse<void>> {
|
||||
const url = this.buildUrl(schema, entity, id);
|
||||
return this.fetchWithError<void>(url, {
|
||||
method: "DELETE",
|
||||
headers: this.baseHeaders(),
|
||||
});
|
||||
}
|
||||
}
|
||||
7
resolvespec-js/src/headerspec/index.ts
Normal file
7
resolvespec-js/src/headerspec/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export {
|
||||
HeaderSpecClient,
|
||||
getHeaderSpecClient,
|
||||
buildHeaders,
|
||||
encodeHeaderValue,
|
||||
decodeHeaderValue,
|
||||
} from './client';
|
||||
@@ -1,7 +1,11 @@
|
||||
// Types
|
||||
export * from './types';
|
||||
export * from './websocket-types';
|
||||
// Common types
|
||||
export * from './common';
|
||||
|
||||
// WebSocket Client
|
||||
export { WebSocketClient } from './websocket-client';
|
||||
export type { WebSocketClient as default } from './websocket-client';
|
||||
// REST client (ResolveSpec)
|
||||
export * from './resolvespec';
|
||||
|
||||
// WebSocket client
|
||||
export * from './websocketspec';
|
||||
|
||||
// HeaderSpec client
|
||||
export * from './headerspec';
|
||||
|
||||
141
resolvespec-js/src/resolvespec/client.ts
Normal file
141
resolvespec-js/src/resolvespec/client.ts
Normal file
@@ -0,0 +1,141 @@
|
||||
import type { ClientConfig, APIResponse, TableMetadata, Options, RequestBody } from '../common/types';
|
||||
|
||||
const instances = new Map<string, ResolveSpecClient>();
|
||||
|
||||
export function getResolveSpecClient(config: ClientConfig): ResolveSpecClient {
|
||||
const key = config.baseUrl;
|
||||
let instance = instances.get(key);
|
||||
if (!instance) {
|
||||
instance = new ResolveSpecClient(config);
|
||||
instances.set(key, instance);
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
export class ResolveSpecClient {
|
||||
private config: ClientConfig;
|
||||
|
||||
constructor(config: ClientConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
private buildUrl(schema: string, entity: string, id?: string): string {
|
||||
let url = `${this.config.baseUrl}/${schema}/${entity}`;
|
||||
if (id) {
|
||||
url += `/${id}`;
|
||||
}
|
||||
return url;
|
||||
}
|
||||
|
||||
private baseHeaders(): HeadersInit {
|
||||
const headers: Record<string, string> = {
|
||||
'Content-Type': 'application/json',
|
||||
};
|
||||
|
||||
if (this.config.token) {
|
||||
headers['Authorization'] = `Bearer ${this.config.token}`;
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
|
||||
private async fetchWithError<T>(url: string, options: RequestInit): Promise<APIResponse<T>> {
|
||||
const response = await fetch(url, options);
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.error?.message || 'An error occurred');
|
||||
}
|
||||
|
||||
return data;
|
||||
}
|
||||
|
||||
async getMetadata(schema: string, entity: string): Promise<APIResponse<TableMetadata>> {
|
||||
const url = this.buildUrl(schema, entity);
|
||||
return this.fetchWithError<TableMetadata>(url, {
|
||||
method: 'GET',
|
||||
headers: this.baseHeaders(),
|
||||
});
|
||||
}
|
||||
|
||||
async read<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
id?: number | string | string[],
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> {
|
||||
const urlId = typeof id === 'number' || typeof id === 'string' ? String(id) : undefined;
|
||||
const url = this.buildUrl(schema, entity, urlId);
|
||||
const body: RequestBody = {
|
||||
operation: 'read',
|
||||
id: Array.isArray(id) ? id : undefined,
|
||||
options,
|
||||
};
|
||||
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
|
||||
async create<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
data: any | any[],
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> {
|
||||
const url = this.buildUrl(schema, entity);
|
||||
const body: RequestBody = {
|
||||
operation: 'create',
|
||||
data,
|
||||
options,
|
||||
};
|
||||
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
|
||||
async update<T = any>(
|
||||
schema: string,
|
||||
entity: string,
|
||||
data: any | any[],
|
||||
id?: number | string | string[],
|
||||
options?: Options
|
||||
): Promise<APIResponse<T>> {
|
||||
const urlId = typeof id === 'number' || typeof id === 'string' ? String(id) : undefined;
|
||||
const url = this.buildUrl(schema, entity, urlId);
|
||||
const body: RequestBody = {
|
||||
operation: 'update',
|
||||
id: Array.isArray(id) ? id : undefined,
|
||||
data,
|
||||
options,
|
||||
};
|
||||
|
||||
return this.fetchWithError<T>(url, {
|
||||
method: 'POST',
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
|
||||
async delete(
|
||||
schema: string,
|
||||
entity: string,
|
||||
id: number | string
|
||||
): Promise<APIResponse<void>> {
|
||||
const url = this.buildUrl(schema, entity, String(id));
|
||||
const body: RequestBody = {
|
||||
operation: 'delete',
|
||||
};
|
||||
|
||||
return this.fetchWithError<void>(url, {
|
||||
method: 'POST',
|
||||
headers: this.baseHeaders(),
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
}
|
||||
}
|
||||
1
resolvespec-js/src/resolvespec/index.ts
Normal file
1
resolvespec-js/src/resolvespec/index.ts
Normal file
@@ -0,0 +1 @@
|
||||
export { ResolveSpecClient, getResolveSpecClient } from './client';
|
||||
@@ -1,86 +0,0 @@
|
||||
// Types
|
||||
export type Operator = 'eq' | 'neq' | 'gt' | 'gte' | 'lt' | 'lte' | 'like' | 'ilike' | 'in';
|
||||
export type Operation = 'read' | 'create' | 'update' | 'delete';
|
||||
export type SortDirection = 'asc' | 'desc';
|
||||
|
||||
export interface PreloadOption {
|
||||
relation: string;
|
||||
columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
}
|
||||
|
||||
export interface FilterOption {
|
||||
column: string;
|
||||
operator: Operator;
|
||||
value: any;
|
||||
}
|
||||
|
||||
export interface SortOption {
|
||||
column: string;
|
||||
direction: SortDirection;
|
||||
}
|
||||
|
||||
export interface CustomOperator {
|
||||
name: string;
|
||||
sql: string;
|
||||
}
|
||||
|
||||
export interface ComputedColumn {
|
||||
name: string;
|
||||
expression: string;
|
||||
}
|
||||
|
||||
export interface Options {
|
||||
preload?: PreloadOption[];
|
||||
columns?: string[];
|
||||
filters?: FilterOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
customOperators?: CustomOperator[];
|
||||
computedColumns?: ComputedColumn[];
|
||||
}
|
||||
|
||||
export interface RequestBody {
|
||||
operation: Operation;
|
||||
id?: string | string[];
|
||||
data?: any | any[];
|
||||
options?: Options;
|
||||
}
|
||||
|
||||
export interface APIResponse<T = any> {
|
||||
success: boolean;
|
||||
data: T;
|
||||
metadata?: {
|
||||
total: number;
|
||||
filtered: number;
|
||||
limit: number;
|
||||
offset: number;
|
||||
};
|
||||
error?: {
|
||||
code: string;
|
||||
message: string;
|
||||
details?: any;
|
||||
};
|
||||
}
|
||||
|
||||
export interface Column {
|
||||
name: string;
|
||||
type: string;
|
||||
is_nullable: boolean;
|
||||
is_primary: boolean;
|
||||
is_unique: boolean;
|
||||
has_index: boolean;
|
||||
}
|
||||
|
||||
export interface TableMetadata {
|
||||
schema: string;
|
||||
table: string;
|
||||
columns: Column[];
|
||||
relations: string[];
|
||||
}
|
||||
|
||||
export interface ClientConfig {
|
||||
baseUrl: string;
|
||||
token?: string;
|
||||
}
|
||||
@@ -1,427 +0,0 @@
|
||||
import { WebSocketClient } from './websocket-client';
|
||||
import type { WSNotificationMessage } from './websocket-types';
|
||||
|
||||
/**
|
||||
* Example 1: Basic Usage
|
||||
*/
|
||||
export async function basicUsageExample() {
|
||||
// Create client
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
debug: true
|
||||
});
|
||||
|
||||
// Connect
|
||||
await client.connect();
|
||||
|
||||
// Read users
|
||||
const users = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
],
|
||||
limit: 10,
|
||||
sort: [
|
||||
{ column: 'name', direction: 'asc' }
|
||||
]
|
||||
});
|
||||
|
||||
console.log('Users:', users);
|
||||
|
||||
// Create a user
|
||||
const newUser = await client.create('users', {
|
||||
name: 'John Doe',
|
||||
email: 'john@example.com',
|
||||
status: 'active'
|
||||
}, { schema: 'public' });
|
||||
|
||||
console.log('Created user:', newUser);
|
||||
|
||||
// Update user
|
||||
const updatedUser = await client.update('users', '123', {
|
||||
name: 'John Updated'
|
||||
}, { schema: 'public' });
|
||||
|
||||
console.log('Updated user:', updatedUser);
|
||||
|
||||
// Delete user
|
||||
await client.delete('users', '123', { schema: 'public' });
|
||||
|
||||
// Disconnect
|
||||
client.disconnect();
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 2: Real-time Subscriptions
|
||||
*/
|
||||
export async function subscriptionExample() {
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
debug: true
|
||||
});
|
||||
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to user changes
|
||||
const subscriptionId = await client.subscribe(
|
||||
'users',
|
||||
(notification: WSNotificationMessage) => {
|
||||
console.log('User changed:', notification.operation, notification.data);
|
||||
|
||||
switch (notification.operation) {
|
||||
case 'create':
|
||||
console.log('New user created:', notification.data);
|
||||
break;
|
||||
case 'update':
|
||||
console.log('User updated:', notification.data);
|
||||
break;
|
||||
case 'delete':
|
||||
console.log('User deleted:', notification.data);
|
||||
break;
|
||||
}
|
||||
},
|
||||
{
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'active' }
|
||||
]
|
||||
}
|
||||
);
|
||||
|
||||
console.log('Subscribed with ID:', subscriptionId);
|
||||
|
||||
// Later: unsubscribe
|
||||
setTimeout(async () => {
|
||||
await client.unsubscribe(subscriptionId);
|
||||
console.log('Unsubscribed');
|
||||
client.disconnect();
|
||||
}, 60000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 3: Event Handling
|
||||
*/
|
||||
export async function eventHandlingExample() {
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws'
|
||||
});
|
||||
|
||||
// Listen to connection events
|
||||
client.on('connect', () => {
|
||||
console.log('Connected!');
|
||||
});
|
||||
|
||||
client.on('disconnect', (event) => {
|
||||
console.log('Disconnected:', event.code, event.reason);
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('WebSocket error:', error);
|
||||
});
|
||||
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('State changed to:', state);
|
||||
});
|
||||
|
||||
client.on('message', (message) => {
|
||||
console.log('Received message:', message);
|
||||
});
|
||||
|
||||
await client.connect();
|
||||
|
||||
// Your operations here...
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 4: Multiple Subscriptions
|
||||
*/
|
||||
export async function multipleSubscriptionsExample() {
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
debug: true
|
||||
});
|
||||
|
||||
await client.connect();
|
||||
|
||||
// Subscribe to users
|
||||
const userSubId = await client.subscribe(
|
||||
'users',
|
||||
(notification) => {
|
||||
console.log('[Users]', notification.operation, notification.data);
|
||||
},
|
||||
{ schema: 'public' }
|
||||
);
|
||||
|
||||
// Subscribe to posts
|
||||
const postSubId = await client.subscribe(
|
||||
'posts',
|
||||
(notification) => {
|
||||
console.log('[Posts]', notification.operation, notification.data);
|
||||
},
|
||||
{
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'published' }
|
||||
]
|
||||
}
|
||||
);
|
||||
|
||||
// Subscribe to comments
|
||||
const commentSubId = await client.subscribe(
|
||||
'comments',
|
||||
(notification) => {
|
||||
console.log('[Comments]', notification.operation, notification.data);
|
||||
},
|
||||
{ schema: 'public' }
|
||||
);
|
||||
|
||||
console.log('Active subscriptions:', client.getSubscriptions());
|
||||
|
||||
// Clean up after 60 seconds
|
||||
setTimeout(async () => {
|
||||
await client.unsubscribe(userSubId);
|
||||
await client.unsubscribe(postSubId);
|
||||
await client.unsubscribe(commentSubId);
|
||||
client.disconnect();
|
||||
}, 60000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 5: Advanced Queries
|
||||
*/
|
||||
export async function advancedQueriesExample() {
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws'
|
||||
});
|
||||
|
||||
await client.connect();
|
||||
|
||||
// Complex query with filters, sorting, pagination, and preloading
|
||||
const posts = await client.read('posts', {
|
||||
schema: 'public',
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'published' },
|
||||
{ column: 'views', operator: 'gte', value: 100 }
|
||||
],
|
||||
columns: ['id', 'title', 'content', 'user_id', 'created_at'],
|
||||
sort: [
|
||||
{ column: 'created_at', direction: 'desc' },
|
||||
{ column: 'views', direction: 'desc' }
|
||||
],
|
||||
preload: [
|
||||
{
|
||||
relation: 'user',
|
||||
columns: ['id', 'name', 'email']
|
||||
},
|
||||
{
|
||||
relation: 'comments',
|
||||
columns: ['id', 'content', 'user_id'],
|
||||
filters: [
|
||||
{ column: 'status', operator: 'eq', value: 'approved' }
|
||||
]
|
||||
}
|
||||
],
|
||||
limit: 20,
|
||||
offset: 0
|
||||
});
|
||||
|
||||
console.log('Posts:', posts);
|
||||
|
||||
// Get single record by ID
|
||||
const post = await client.read('posts', {
|
||||
schema: 'public',
|
||||
record_id: '123'
|
||||
});
|
||||
|
||||
console.log('Single post:', post);
|
||||
|
||||
client.disconnect();
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 6: Error Handling
|
||||
*/
|
||||
export async function errorHandlingExample() {
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws',
|
||||
reconnect: true,
|
||||
maxReconnectAttempts: 5
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('Connection error:', error);
|
||||
});
|
||||
|
||||
client.on('stateChange', (state) => {
|
||||
console.log('Connection state:', state);
|
||||
});
|
||||
|
||||
try {
|
||||
await client.connect();
|
||||
|
||||
try {
|
||||
// Try to read non-existent entity
|
||||
await client.read('nonexistent', { schema: 'public' });
|
||||
} catch (error) {
|
||||
console.error('Read error:', error);
|
||||
}
|
||||
|
||||
try {
|
||||
// Try to create invalid record
|
||||
await client.create('users', {
|
||||
// Missing required fields
|
||||
}, { schema: 'public' });
|
||||
} catch (error) {
|
||||
console.error('Create error:', error);
|
||||
}
|
||||
|
||||
} catch (error) {
|
||||
console.error('Connection failed:', error);
|
||||
} finally {
|
||||
client.disconnect();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 7: React Integration
|
||||
*/
|
||||
export function reactIntegrationExample() {
|
||||
const exampleCode = `
|
||||
import { useEffect, useState } from 'react';
|
||||
import { WebSocketClient } from '@warkypublic/resolvespec-js';
|
||||
|
||||
export function useWebSocket(url: string) {
|
||||
const [client] = useState(() => new WebSocketClient({ url }));
|
||||
const [isConnected, setIsConnected] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
client.on('connect', () => setIsConnected(true));
|
||||
client.on('disconnect', () => setIsConnected(false));
|
||||
|
||||
client.connect();
|
||||
|
||||
return () => {
|
||||
client.disconnect();
|
||||
};
|
||||
}, [client]);
|
||||
|
||||
return { client, isConnected };
|
||||
}
|
||||
|
||||
export function UsersComponent() {
|
||||
const { client, isConnected } = useWebSocket('ws://localhost:8080/ws');
|
||||
const [users, setUsers] = useState([]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isConnected) return;
|
||||
|
||||
// Subscribe to user changes
|
||||
const subscribeToUsers = async () => {
|
||||
const subId = await client.subscribe('users', (notification) => {
|
||||
if (notification.operation === 'create') {
|
||||
setUsers(prev => [...prev, notification.data]);
|
||||
} else if (notification.operation === 'update') {
|
||||
setUsers(prev => prev.map(u =>
|
||||
u.id === notification.data.id ? notification.data : u
|
||||
));
|
||||
} else if (notification.operation === 'delete') {
|
||||
setUsers(prev => prev.filter(u => u.id !== notification.data.id));
|
||||
}
|
||||
}, { schema: 'public' });
|
||||
|
||||
// Load initial users
|
||||
const initialUsers = await client.read('users', {
|
||||
schema: 'public',
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
setUsers(initialUsers);
|
||||
|
||||
return () => client.unsubscribe(subId);
|
||||
};
|
||||
|
||||
subscribeToUsers();
|
||||
}, [client, isConnected]);
|
||||
|
||||
const createUser = async (name: string, email: string) => {
|
||||
await client.create('users', { name, email, status: 'active' }, {
|
||||
schema: 'public'
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h2>Users ({users.length})</h2>
|
||||
{isConnected ? '🟢 Connected' : '🔴 Disconnected'}
|
||||
{/* Render users... */}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
`;
|
||||
|
||||
console.log(exampleCode);
|
||||
}
|
||||
|
||||
/**
|
||||
* Example 8: TypeScript with Typed Models
|
||||
*/
|
||||
export async function typedModelsExample() {
|
||||
// Define your models
|
||||
interface User {
|
||||
id: number;
|
||||
name: string;
|
||||
email: string;
|
||||
status: 'active' | 'inactive';
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
interface Post {
|
||||
id: number;
|
||||
title: string;
|
||||
content: string;
|
||||
user_id: number;
|
||||
status: 'draft' | 'published';
|
||||
views: number;
|
||||
user?: User;
|
||||
}
|
||||
|
||||
const client = new WebSocketClient({
|
||||
url: 'ws://localhost:8080/ws'
|
||||
});
|
||||
|
||||
await client.connect();
|
||||
|
||||
// Type-safe operations
|
||||
const users = await client.read<User[]>('users', {
|
||||
schema: 'public',
|
||||
filters: [{ column: 'status', operator: 'eq', value: 'active' }]
|
||||
});
|
||||
|
||||
const newUser = await client.create<User>('users', {
|
||||
name: 'Alice',
|
||||
email: 'alice@example.com',
|
||||
status: 'active'
|
||||
}, { schema: 'public' });
|
||||
|
||||
const posts = await client.read<Post[]>('posts', {
|
||||
schema: 'public',
|
||||
preload: [
|
||||
{
|
||||
relation: 'user',
|
||||
columns: ['id', 'name', 'email']
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
// Type-safe subscriptions
|
||||
await client.subscribe(
|
||||
'users',
|
||||
(notification) => {
|
||||
const user = notification.data as User;
|
||||
console.log('User changed:', user.name, user.email);
|
||||
},
|
||||
{ schema: 'public' }
|
||||
);
|
||||
|
||||
client.disconnect();
|
||||
}
|
||||
@@ -8,10 +8,22 @@ import type {
|
||||
WSOperation,
|
||||
WSOptions,
|
||||
Subscription,
|
||||
SubscriptionOptions,
|
||||
ConnectionState,
|
||||
WebSocketClientEvents
|
||||
} from './websocket-types';
|
||||
} from './types';
|
||||
import type { FilterOption, SortOption, PreloadOption } from '../common/types';
|
||||
|
||||
const instances = new Map<string, WebSocketClient>();
|
||||
|
||||
export function getWebSocketClient(config: WebSocketClientConfig): WebSocketClient {
|
||||
const key = config.url;
|
||||
let instance = instances.get(key);
|
||||
if (!instance) {
|
||||
instance = new WebSocketClient(config);
|
||||
instances.set(key, instance);
|
||||
}
|
||||
return instance;
|
||||
}
|
||||
|
||||
export class WebSocketClient {
|
||||
private ws: WebSocket | null = null;
|
||||
@@ -36,9 +48,6 @@ export class WebSocketClient {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to WebSocket server
|
||||
*/
|
||||
async connect(): Promise<void> {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
this.log('Already connected');
|
||||
@@ -78,7 +87,6 @@ export class WebSocketClient {
|
||||
this.setState('disconnected');
|
||||
this.emit('disconnect', event);
|
||||
|
||||
// Attempt reconnection if enabled and not manually closed
|
||||
if (this.config.reconnect && !this.isManualClose && this.reconnectAttempts < this.config.maxReconnectAttempts) {
|
||||
this.reconnectAttempts++;
|
||||
this.log(`Reconnection attempt ${this.reconnectAttempts}/${this.config.maxReconnectAttempts}`);
|
||||
@@ -97,9 +105,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from WebSocket server
|
||||
*/
|
||||
disconnect(): void {
|
||||
this.isManualClose = true;
|
||||
|
||||
@@ -120,9 +125,6 @@ export class WebSocketClient {
|
||||
this.messageHandlers.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a CRUD request and wait for response
|
||||
*/
|
||||
async request<T = any>(
|
||||
operation: WSOperation,
|
||||
entity: string,
|
||||
@@ -148,7 +150,6 @@ export class WebSocketClient {
|
||||
};
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
// Set up response handler
|
||||
this.messageHandlers.set(id, (response: WSResponseMessage) => {
|
||||
if (response.success) {
|
||||
resolve(response.data);
|
||||
@@ -157,10 +158,8 @@ export class WebSocketClient {
|
||||
}
|
||||
});
|
||||
|
||||
// Send message
|
||||
this.send(message);
|
||||
|
||||
// Timeout after 30 seconds
|
||||
setTimeout(() => {
|
||||
if (this.messageHandlers.has(id)) {
|
||||
this.messageHandlers.delete(id);
|
||||
@@ -170,16 +169,13 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Read records
|
||||
*/
|
||||
async read<T = any>(entity: string, options?: {
|
||||
schema?: string;
|
||||
record_id?: string;
|
||||
filters?: import('./types').FilterOption[];
|
||||
filters?: FilterOption[];
|
||||
columns?: string[];
|
||||
sort?: import('./types').SortOption[];
|
||||
preload?: import('./types').PreloadOption[];
|
||||
sort?: SortOption[];
|
||||
preload?: PreloadOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
}): Promise<T> {
|
||||
@@ -197,9 +193,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a record
|
||||
*/
|
||||
async create<T = any>(entity: string, data: any, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T> {
|
||||
@@ -209,9 +202,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a record
|
||||
*/
|
||||
async update<T = any>(entity: string, id: string, data: any, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T> {
|
||||
@@ -222,9 +212,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a record
|
||||
*/
|
||||
async delete(entity: string, id: string, options?: {
|
||||
schema?: string;
|
||||
}): Promise<void> {
|
||||
@@ -234,9 +221,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get metadata for an entity
|
||||
*/
|
||||
async meta<T = any>(entity: string, options?: {
|
||||
schema?: string;
|
||||
}): Promise<T> {
|
||||
@@ -245,15 +229,12 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Subscribe to entity changes
|
||||
*/
|
||||
async subscribe(
|
||||
entity: string,
|
||||
callback: (notification: WSNotificationMessage) => void,
|
||||
options?: {
|
||||
schema?: string;
|
||||
filters?: import('./types').FilterOption[];
|
||||
filters?: FilterOption[];
|
||||
}
|
||||
): Promise<string> {
|
||||
this.ensureConnected();
|
||||
@@ -275,7 +256,6 @@ export class WebSocketClient {
|
||||
if (response.success && response.data?.subscription_id) {
|
||||
const subscriptionId = response.data.subscription_id;
|
||||
|
||||
// Store subscription
|
||||
this.subscriptions.set(subscriptionId, {
|
||||
id: subscriptionId,
|
||||
entity,
|
||||
@@ -293,7 +273,6 @@ export class WebSocketClient {
|
||||
|
||||
this.send(message);
|
||||
|
||||
// Timeout
|
||||
setTimeout(() => {
|
||||
if (this.messageHandlers.has(id)) {
|
||||
this.messageHandlers.delete(id);
|
||||
@@ -303,9 +282,6 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Unsubscribe from entity changes
|
||||
*/
|
||||
async unsubscribe(subscriptionId: string): Promise<void> {
|
||||
this.ensureConnected();
|
||||
|
||||
@@ -330,7 +306,6 @@ export class WebSocketClient {
|
||||
|
||||
this.send(message);
|
||||
|
||||
// Timeout
|
||||
setTimeout(() => {
|
||||
if (this.messageHandlers.has(id)) {
|
||||
this.messageHandlers.delete(id);
|
||||
@@ -340,37 +315,22 @@ export class WebSocketClient {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of active subscriptions
|
||||
*/
|
||||
getSubscriptions(): Subscription[] {
|
||||
return Array.from(this.subscriptions.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection state
|
||||
*/
|
||||
getState(): ConnectionState {
|
||||
return this.state;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if connected
|
||||
*/
|
||||
isConnected(): boolean {
|
||||
return this.ws?.readyState === WebSocket.OPEN;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add event listener
|
||||
*/
|
||||
on<K extends keyof WebSocketClientEvents>(event: K, callback: WebSocketClientEvents[K]): void {
|
||||
this.eventListeners[event] = callback as any;
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove event listener
|
||||
*/
|
||||
off<K extends keyof WebSocketClientEvents>(event: K): void {
|
||||
delete this.eventListeners[event];
|
||||
}
|
||||
@@ -384,7 +344,6 @@ export class WebSocketClient {
|
||||
|
||||
this.emit('message', message);
|
||||
|
||||
// Handle different message types
|
||||
switch (message.type) {
|
||||
case 'response':
|
||||
this.handleResponse(message as WSResponseMessage);
|
||||
@@ -395,7 +354,6 @@ export class WebSocketClient {
|
||||
break;
|
||||
|
||||
case 'pong':
|
||||
// Heartbeat response
|
||||
break;
|
||||
|
||||
default:
|
||||
2
resolvespec-js/src/websocketspec/index.ts
Normal file
2
resolvespec-js/src/websocketspec/index.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export * from './types';
|
||||
export { WebSocketClient, getWebSocketClient } from './client';
|
||||
@@ -1,17 +1,24 @@
|
||||
import type { FilterOption, SortOption, PreloadOption, Parameter } from '../common/types';
|
||||
|
||||
// Re-export common types
|
||||
export type { FilterOption, SortOption, PreloadOption, Operator, SortDirection } from '../common/types';
|
||||
|
||||
// WebSocket Message Types
|
||||
export type MessageType = 'request' | 'response' | 'notification' | 'subscription' | 'error' | 'ping' | 'pong';
|
||||
export type WSOperation = 'read' | 'create' | 'update' | 'delete' | 'subscribe' | 'unsubscribe' | 'meta';
|
||||
|
||||
// Re-export common types
|
||||
export type { FilterOption, SortOption, PreloadOption, Operator, SortDirection } from './types';
|
||||
|
||||
export interface WSOptions {
|
||||
filters?: import('./types').FilterOption[];
|
||||
filters?: FilterOption[];
|
||||
columns?: string[];
|
||||
preload?: import('./types').PreloadOption[];
|
||||
sort?: import('./types').SortOption[];
|
||||
omit_columns?: string[];
|
||||
preload?: PreloadOption[];
|
||||
sort?: SortOption[];
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
parameters?: Parameter[];
|
||||
cursor_forward?: string;
|
||||
cursor_backward?: string;
|
||||
fetch_row_number?: string;
|
||||
}
|
||||
|
||||
export interface WSMessage {
|
||||
@@ -78,7 +85,7 @@ export interface WSSubscriptionMessage {
|
||||
}
|
||||
|
||||
export interface SubscriptionOptions {
|
||||
filters?: import('./types').FilterOption[];
|
||||
filters?: FilterOption[];
|
||||
onNotification?: (notification: WSNotificationMessage) => void;
|
||||
}
|
||||
|
||||
21
resolvespec-js/tsconfig.json
Normal file
21
resolvespec-js/tsconfig.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"strict": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"outDir": "dist",
|
||||
"rootDir": "src",
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"isolatedModules": true,
|
||||
"lib": ["ES2020", "DOM"]
|
||||
},
|
||||
"include": ["src"],
|
||||
"exclude": ["node_modules", "dist", "src/__tests__"]
|
||||
}
|
||||
20
resolvespec-js/vite.config.ts
Normal file
20
resolvespec-js/vite.config.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
import { defineConfig } from 'vite';
|
||||
import dts from 'vite-plugin-dts';
|
||||
import { resolve } from 'path';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [
|
||||
dts({ rollupTypes: true }),
|
||||
],
|
||||
build: {
|
||||
lib: {
|
||||
entry: resolve(__dirname, 'src/index.ts'),
|
||||
name: 'ResolveSpec',
|
||||
formats: ['es', 'cjs'],
|
||||
fileName: (format) => `index.${format === 'es' ? 'js' : 'cjs'}`,
|
||||
},
|
||||
rollupOptions: {
|
||||
external: ['uuid', 'semver'],
|
||||
},
|
||||
},
|
||||
});
|
||||
Reference in New Issue
Block a user