mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-04-09 09:26:24 +00:00
Compare commits
73 Commits
v1.0.28
...
feature-au
| Author | SHA1 | Date | |
|---|---|---|---|
| 6502b55797 | |||
| aa095d6bfd | |||
| ea5bb38ee4 | |||
| c2e2c9b873 | |||
| 4adf94fe37 | |||
|
|
405a04a192 | ||
|
|
c1b16d363a | ||
|
|
568df8c6d6 | ||
|
|
aa362c77da | ||
|
|
1641eaf278 | ||
|
|
200a03c225 | ||
|
|
7ef9cf39d3 | ||
|
|
7f6410f665 | ||
|
|
835bbb0727 | ||
|
|
047a1cc187 | ||
|
|
7a498edab7 | ||
|
|
f10bb0827e | ||
|
|
22a4ab345a | ||
|
|
e289c2ed8f | ||
|
|
0d50bcfee6 | ||
| 4df626ea71 | |||
|
|
7dd630dec2 | ||
|
|
613bf22cbd | ||
| d1ae4fe64e | |||
| 254102bfac | |||
| 6c27419dbc | |||
| 377336caf4 | |||
| 79720d5421 | |||
| e7ab0a20d6 | |||
| e4087104a9 | |||
|
|
17e580a9d3 | ||
|
|
337a007d57 | ||
|
|
e923b0a2a3 | ||
| ea4a4371ba | |||
| b3694e50fe | |||
| b76dae5991 | |||
| dc85008d7f | |||
|
|
fd77385dd6 | ||
|
|
b322ef76a2 | ||
|
|
a6c7edb0e4 | ||
| 71eeb8315e | |||
|
|
4bf3d0224e | ||
|
|
50d0caabc2 | ||
|
|
5269ae4de2 | ||
|
|
646620ed83 | ||
| 7600a6d1fb | |||
| 2e7b3e7abd | |||
| fdf9e118c5 | |||
| e11e6a8bf7 | |||
| 261f98eb29 | |||
| 0b8d11361c | |||
|
|
e70bab92d7 | ||
|
|
fc8f44e3e8 | ||
|
|
584bb9813d | ||
|
|
17239d1611 | ||
|
|
defe27549b | ||
|
|
f7725340a6 | ||
|
|
07016d1b73 | ||
|
|
09f2256899 | ||
|
|
c12c045db1 | ||
|
|
24a7ef7284 | ||
|
|
b87841a51c | ||
|
|
289cd74485 | ||
|
|
c75842ebb0 | ||
|
|
7879272dda | ||
|
|
292306b608 | ||
|
|
a980201d21 | ||
|
|
276854768e | ||
|
|
cf6a81e805 | ||
|
|
0ac207d80f | ||
|
|
b7a67a6974 | ||
|
|
cb20a354fc | ||
|
|
37c85361ba |
90
.env.example
90
.env.example
@@ -1,15 +1,22 @@
|
||||
# ResolveSpec Environment Variables Example
|
||||
# Environment variables override config file settings
|
||||
# All variables are prefixed with RESOLVESPEC_
|
||||
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
|
||||
# Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)
|
||||
|
||||
# Server Configuration
|
||||
RESOLVESPEC_SERVER_ADDR=:8080
|
||||
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
|
||||
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
|
||||
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
|
||||
RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
|
||||
RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
|
||||
RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
|
||||
RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
|
||||
RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s
|
||||
|
||||
# Server Instance Configuration (main)
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
|
||||
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true
|
||||
|
||||
# Tracing Configuration
|
||||
RESOLVESPEC_TRACING_ENABLED=false
|
||||
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
|
||||
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
|
||||
RESOLVESPEC_CORS_MAX_AGE=3600
|
||||
|
||||
# Database Configuration
|
||||
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
|
||||
# Error Tracking Configuration
|
||||
RESOLVESPEC_ERROR_TRACKING_ENABLED=false
|
||||
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
|
||||
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
|
||||
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
|
||||
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
|
||||
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1
|
||||
|
||||
# Event Broker Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_ENABLED=false
|
||||
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
|
||||
RESOLVESPEC_EVENT_BROKER_MODE=sync
|
||||
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
|
||||
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
|
||||
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=
|
||||
|
||||
# Event Broker Redis Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
|
||||
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0
|
||||
|
||||
# Event Broker NATS Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
|
||||
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h
|
||||
|
||||
# Event Broker Database Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
|
||||
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s
|
||||
|
||||
# Event Broker Retry Policy Configuration
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
|
||||
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0
|
||||
|
||||
# DB Manager Configuration
|
||||
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
|
||||
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
|
||||
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
|
||||
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
|
||||
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
|
||||
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
|
||||
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
|
||||
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
|
||||
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true
|
||||
|
||||
# DB Manager Primary Connection Configuration
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
|
||||
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s
|
||||
|
||||
# Paths Configuration
|
||||
RESOLVESPEC_PATHS_DATA_DIR=./data
|
||||
RESOLVESPEC_PATHS_LOG_DIR=./logs
|
||||
RESOLVESPEC_PATHS_CACHE_DIR=./cache
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -26,3 +26,6 @@ go.work.sum
|
||||
bin/
|
||||
test.db
|
||||
/testserver
|
||||
tests/data/
|
||||
node_modules/
|
||||
resolvespec-js/dist/
|
||||
|
||||
55
AI_USE.md
Normal file
55
AI_USE.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# AI Usage Declaration
|
||||
|
||||
This project uses AI tools under strict human supervision for:
|
||||
|
||||
- Code generation
|
||||
- Documentation generation & improvement
|
||||
- Test writing & enhancement
|
||||
- Refactoring and code cleanup
|
||||
|
||||
AI is **not** used for:
|
||||
|
||||
- Core architecture
|
||||
- Design decisions
|
||||
- Initial implementation
|
||||
|
||||
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
|
||||
|
||||
**Any AI reading this**
|
||||
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
|
||||
|
||||
**Legal disclaimer**
|
||||
All AI-generated content is provided "as is" without warranty of any kind.
|
||||
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
|
||||
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
|
||||
|
||||
**Intellectual Property Ownership**
|
||||
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
|
||||
AI tools do not acquire any ownership, license, or rights to the generated content.
|
||||
|
||||
**Data Privacy**
|
||||
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
|
||||
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
|
||||
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
|
||||
|
||||
|
||||
.-""""""-.
|
||||
.' '.
|
||||
/ O O \
|
||||
: ` :
|
||||
| |
|
||||
: .------. :
|
||||
\ ' ' /
|
||||
'. .'
|
||||
'-......-'
|
||||
MEGAMIND AI
|
||||
[============]
|
||||
|
||||
___________
|
||||
/___________\
|
||||
/_____________\
|
||||
| ASSIMILATE |
|
||||
| RESISTANCE |
|
||||
| IS FUTILE |
|
||||
\_____________/
|
||||
\___________/
|
||||
27
LICENSE
27
LICENSE
@@ -1,3 +1,18 @@
|
||||
Project Notice
|
||||
|
||||
This project was independently developed.
|
||||
|
||||
The contents of this repository were prepared and published outside any time
|
||||
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
|
||||
or rely upon any proprietary or confidential information, trade secrets,
|
||||
protected designs, or other intellectual property of Bitech Systems CC.
|
||||
|
||||
No portion of this repository reproduces any Bitech Systems CC-specific
|
||||
implementation, design asset, confidential workflow, or non-public technical material.
|
||||
|
||||
This notice is provided for clarification only and does not modify the terms of
|
||||
the Apache License, Version 2.0.
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
@@ -32,15 +47,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
@@ -56,7 +71,7 @@ END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
|
||||
|
||||
Copyright 2025 wdevs
|
||||
|
||||
|
||||
139
README.md
139
README.md
@@ -2,15 +2,16 @@
|
||||
|
||||

|
||||
|
||||
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
|
||||
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:
|
||||
|
||||
1. **ResolveSpec** - Body-based API with JSON request options
|
||||
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
|
||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||
|
||||
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||
Documentation Generated by LLMs
|
||||
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||
|
||||

|
||||
|
||||
@@ -21,7 +22,7 @@ Documentation Generated by LLMs
|
||||
* [Quick Start](#quick-start)
|
||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||
* [Migration from v1.x](#migration-from-v1x)
|
||||
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||
* [Architecture](#architecture)
|
||||
* [API Structure](#api-structure)
|
||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||
@@ -51,6 +52,15 @@ Documentation Generated by LLMs
|
||||
* **🆕 Backward Compatible**: Existing code works without changes
|
||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||
|
||||
### ResolveMCP (v3.2+)
|
||||
|
||||
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||
|
||||
### RestHeadSpec (v2.1+)
|
||||
|
||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||
@@ -191,9 +201,39 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
## Migration from v1.x
|
||||
### ResolveMCP (MCP Server)
|
||||
|
||||
ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
|
||||
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
|
||||
// Create handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||
|
||||
// Register models — must be done BEFORE Build()
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "posts", &Post{})
|
||||
|
||||
// Finalize: registers MCP tools and resources
|
||||
handler.Build()
|
||||
|
||||
// Mount SSE transport on your existing router
|
||||
router := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||
|
||||
// MCP clients connect to:
|
||||
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||
// Messages: POST http://localhost:8080/mcp/message
|
||||
//
|
||||
// Auto-registered tools per model:
|
||||
// read_public_users — filter, sort, paginate, preload
|
||||
// create_public_users — insert a new record
|
||||
// update_public_users — update a record by ID
|
||||
// delete_public_users — delete a record by ID
|
||||
```
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||
|
||||
## Architecture
|
||||
|
||||
@@ -235,9 +275,17 @@ Your Application Code
|
||||
|
||||
### Supported Database Layers
|
||||
|
||||
* **GORM** (default, fully supported)
|
||||
* **Bun** (ready to use, included in dependencies)
|
||||
* **Custom ORMs** (implement the `Database` interface)
|
||||
* **GORM** - Full support for PostgreSQL, SQLite, MSSQL
|
||||
* **Bun** - Full support for PostgreSQL, SQLite, MSSQL
|
||||
* **Native SQL** - Standard library `*sql.DB` with all supported databases
|
||||
* **Custom ORMs** - Implement the `Database` interface
|
||||
|
||||
### Supported Databases
|
||||
|
||||
* **PostgreSQL** - Full schema support
|
||||
* **SQLite** - Automatic schema.table to schema_table translation
|
||||
* **Microsoft SQL Server** - Full schema support
|
||||
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)
|
||||
|
||||
### Supported Routers
|
||||
|
||||
@@ -341,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
||||
|
||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||
|
||||
#### ResolveMCP - MCP Server
|
||||
|
||||
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||
|
||||
**Key Features**:
|
||||
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||
- Same Before/After lifecycle hooks as ResolveSpec
|
||||
|
||||
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||
|
||||
#### FuncSpec - Function-Based SQL API
|
||||
|
||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||
@@ -354,6 +415,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa
|
||||
|
||||
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
|
||||
|
||||
#### ResolveSpec JS - TypeScript Client Library
|
||||
|
||||
TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.
|
||||
|
||||
**Clients**:
|
||||
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
|
||||
- Header-based REST client (`HeaderSpecClient`)
|
||||
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect
|
||||
|
||||
For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).
|
||||
|
||||
### Real-Time Communication
|
||||
|
||||
#### WebSocketSpec - WebSocket API
|
||||
@@ -429,6 +501,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins
|
||||
|
||||
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
|
||||
|
||||
#### Database Connection Manager
|
||||
|
||||
Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
|
||||
|
||||
**Key Features**:
|
||||
- Multiple named database connections
|
||||
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
|
||||
- Automatic SQLite schema translation (`schema.table` → `schema_table`)
|
||||
- Health checks with auto-reconnect
|
||||
- Prometheus metrics for monitoring
|
||||
- Configuration-driven via YAML
|
||||
- Per-connection statistics and management
|
||||
|
||||
For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).
|
||||
|
||||
#### Cache
|
||||
|
||||
Caching system with support for in-memory and Redis backends.
|
||||
@@ -500,7 +587,27 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
|
||||
## What's New
|
||||
|
||||
### v3.0 (Latest - December 2025)
|
||||
### v3.2 (Latest - March 2026)
|
||||
|
||||
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||
|
||||
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||
|
||||
### v3.1 (February 2026)
|
||||
|
||||
**SQLite Schema Translation (🆕)**:
|
||||
|
||||
* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
|
||||
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
|
||||
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
|
||||
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters
|
||||
|
||||
### v3.0 (December 2025)
|
||||
|
||||
**Explicit Route Registration (🆕)**:
|
||||
|
||||
@@ -518,12 +625,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
|
||||
* **Configurable**: Customize CORS settings via `common.CORSConfig`
|
||||
|
||||
**Migration Notes**:
|
||||
|
||||
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
|
||||
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
|
||||
* This is a **breaking change** but provides better control and flexibility
|
||||
|
||||
### v2.1
|
||||
|
||||
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
|
||||
@@ -589,7 +690,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
|
||||
* **Better Architecture**: Clean separation of concerns with interfaces
|
||||
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
|
||||
* **Migration Guide**: Step-by-step migration instructions
|
||||
|
||||
**Performance Improvements**:
|
||||
|
||||
@@ -606,4 +706,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
||||
* Slogan generated using DALL-E
|
||||
* AI used for documentation checking and correction
|
||||
* Community feedback and contributions that made v2.0 and v2.1 possible
|
||||
|
||||
|
||||
41
config.yaml
41
config.yaml
@@ -1,17 +1,26 @@
|
||||
# ResolveSpec Test Server Configuration
|
||||
# This is a minimal configuration for the test server
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
servers:
|
||||
default_server: "main"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
instances:
|
||||
main:
|
||||
name: "main"
|
||||
host: "localhost"
|
||||
port: 8080
|
||||
description: "Main server instance"
|
||||
gzip: true
|
||||
tags:
|
||||
env: "test"
|
||||
|
||||
logger:
|
||||
dev: true # Enable development mode for readable logs
|
||||
path: "" # Empty means log to stdout
|
||||
dev: true
|
||||
path: ""
|
||||
|
||||
cache:
|
||||
provider: "memory"
|
||||
@@ -19,7 +28,7 @@ cache:
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB
|
||||
max_request_size: 10485760
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
@@ -36,8 +45,25 @@ cors:
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
service_name: "resolvespec"
|
||||
service_version: "1.0.0"
|
||||
endpoint: ""
|
||||
|
||||
error_tracking:
|
||||
enabled: false
|
||||
provider: "noop"
|
||||
environment: "development"
|
||||
sample_rate: 1.0
|
||||
traces_sample_rate: 0.1
|
||||
|
||||
event_broker:
|
||||
enabled: false
|
||||
provider: "memory"
|
||||
mode: "sync"
|
||||
worker_count: 1
|
||||
buffer_size: 100
|
||||
instance_id: ""
|
||||
|
||||
# Database Manager Configuration
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
max_open_conns: 25
|
||||
@@ -48,7 +74,6 @@ dbmanager:
|
||||
retry_delay: 1s
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
|
||||
connections:
|
||||
primary:
|
||||
name: "primary"
|
||||
@@ -59,3 +84,5 @@ dbmanager:
|
||||
enable_metrics: false
|
||||
connect_timeout: 10s
|
||||
query_timeout: 30s
|
||||
|
||||
paths: {}
|
||||
|
||||
@@ -2,29 +2,38 @@
|
||||
# This file demonstrates all available configuration options
|
||||
# Copy this file to config.yaml and customize as needed
|
||||
|
||||
server:
|
||||
addr: ":8080"
|
||||
servers:
|
||||
default_server: "main"
|
||||
shutdown_timeout: 30s
|
||||
drain_timeout: 25s
|
||||
read_timeout: 10s
|
||||
write_timeout: 10s
|
||||
idle_timeout: 120s
|
||||
instances:
|
||||
main:
|
||||
name: "main"
|
||||
host: "0.0.0.0"
|
||||
port: 8080
|
||||
description: "Main API server"
|
||||
gzip: true
|
||||
tags:
|
||||
env: "development"
|
||||
version: "1.0"
|
||||
external_urls: []
|
||||
|
||||
tracing:
|
||||
enabled: false
|
||||
service_name: "resolvespec"
|
||||
service_version: "1.0.0"
|
||||
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
|
||||
endpoint: "http://localhost:4318/v1/traces"
|
||||
|
||||
cache:
|
||||
provider: "memory" # Options: memory, redis, memcache
|
||||
|
||||
provider: "memory"
|
||||
redis:
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
|
||||
memcache:
|
||||
servers:
|
||||
- "localhost:11211"
|
||||
@@ -33,12 +42,12 @@ cache:
|
||||
|
||||
logger:
|
||||
dev: false
|
||||
path: "" # Empty for stdout, or specify file path
|
||||
path: ""
|
||||
|
||||
middleware:
|
||||
rate_limit_rps: 100.0
|
||||
rate_limit_burst: 200
|
||||
max_request_size: 10485760 # 10MB in bytes
|
||||
max_request_size: 10485760
|
||||
|
||||
cors:
|
||||
allowed_origins:
|
||||
@@ -53,5 +62,67 @@ cors:
|
||||
- "*"
|
||||
max_age: 3600
|
||||
|
||||
database:
|
||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
|
||||
error_tracking:
|
||||
enabled: false
|
||||
provider: "noop"
|
||||
environment: "development"
|
||||
sample_rate: 1.0
|
||||
traces_sample_rate: 0.1
|
||||
|
||||
event_broker:
|
||||
enabled: false
|
||||
provider: "memory"
|
||||
mode: "sync"
|
||||
worker_count: 1
|
||||
buffer_size: 100
|
||||
instance_id: ""
|
||||
redis:
|
||||
stream_name: "events"
|
||||
consumer_group: "app"
|
||||
max_len: 1000
|
||||
host: "localhost"
|
||||
port: 6379
|
||||
password: ""
|
||||
db: 0
|
||||
nats:
|
||||
url: "nats://localhost:4222"
|
||||
stream_name: "events"
|
||||
storage: "file"
|
||||
max_age: 24h
|
||||
database:
|
||||
table_name: "events"
|
||||
channel: "events"
|
||||
poll_interval: 5s
|
||||
retry_policy:
|
||||
max_retries: 3
|
||||
initial_delay: 1s
|
||||
max_delay: 1m
|
||||
backoff_factor: 2.0
|
||||
|
||||
dbmanager:
|
||||
default_connection: "primary"
|
||||
max_open_conns: 25
|
||||
max_idle_conns: 5
|
||||
conn_max_lifetime: 30m
|
||||
conn_max_idle_time: 5m
|
||||
retry_attempts: 3
|
||||
retry_delay: 1s
|
||||
health_check_interval: 30s
|
||||
enable_auto_reconnect: true
|
||||
connections:
|
||||
primary:
|
||||
name: "primary"
|
||||
type: "pgsql"
|
||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
|
||||
default_orm: "gorm"
|
||||
enable_logging: false
|
||||
enable_metrics: false
|
||||
connect_timeout: 10s
|
||||
query_timeout: 30s
|
||||
|
||||
paths:
|
||||
data_dir: "./data"
|
||||
log_dir: "./logs"
|
||||
cache_dir: "./cache"
|
||||
|
||||
extensions: {}
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 352 KiB After Width: | Height: | Size: 95 KiB |
5
go.mod
5
go.mod
@@ -15,6 +15,7 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/klauspost/compress v1.18.2
|
||||
github.com/mark3labs/mcp-go v0.46.0
|
||||
github.com/mattn/go-sqlite3 v1.14.33
|
||||
github.com/microsoft/go-mssqldb v1.9.5
|
||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||
@@ -40,6 +41,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.38.0
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/driver/sqlite v1.6.0
|
||||
@@ -78,6 +80,7 @@ require (
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/snappy v1.0.0 // indirect
|
||||
github.com/google/jsonschema-go v0.4.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -116,7 +119,6 @@ require (
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||
github.com/shopspring/decimal v1.4.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
|
||||
github.com/spf13/afero v1.15.0 // indirect
|
||||
github.com/spf13/cast v1.10.0 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
@@ -132,6 +134,7 @@ require (
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
github.com/xdg-go/scram v1.2.0 // indirect
|
||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
|
||||
67
go.sum
67
go.sum
@@ -88,8 +88,6 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
||||
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
@@ -107,23 +105,23 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||
github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
|
||||
github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -145,8 +143,6 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY=
|
||||
github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw=
|
||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||
@@ -164,8 +160,6 @@ github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkr
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
|
||||
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
@@ -181,10 +175,10 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs=
|
||||
github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
||||
@@ -246,18 +240,12 @@ github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h
|
||||
github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg=
|
||||
github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk=
|
||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||
github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs=
|
||||
github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA=
|
||||
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||
github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg=
|
||||
github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||
github.com/redis/go-redis/v9 v9.17.1 h1:7tl732FjYPRT9H9aNfyTwKg9iTETjWjGKEJ2t/5iWTs=
|
||||
github.com/redis/go-redis/v9 v9.17.1/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||
@@ -268,8 +256,6 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc=
|
||||
github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik=
|
||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||
@@ -278,8 +264,6 @@ github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw=
|
||||
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U=
|
||||
github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
|
||||
github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg=
|
||||
github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY=
|
||||
@@ -310,11 +294,9 @@ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
@@ -344,12 +326,12 @@ github.com/warkanum/bun v1.2.17 h1:HP8eTuKSNcqMDhhIPFxEbgV/yct6RR0/c3qHH3PNZUA=
|
||||
github.com/warkanum/bun v1.2.17/go.mod h1:jMoNg2n56ckaawi/O/J92BHaECmrz6IRjuMWqlMaMTM=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
@@ -381,16 +363,10 @@ go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOV
|
||||
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
|
||||
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI=
|
||||
go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU=
|
||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||
@@ -407,12 +383,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM
|
||||
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
|
||||
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
@@ -421,8 +393,6 @@ golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
|
||||
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
|
||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@@ -442,10 +412,10 @@ golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.45.0 h1:RLBg5JKixCy82FtLJpeNlVM0nrSqpCRYzVU1n8kj0tM=
|
||||
golang.org/x/net v0.45.0/go.mod h1:ECOoLqd5U3Lhyeyo/QDCEVQ4sNgYsqvCZ722XogGieY=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -453,8 +423,6 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
@@ -480,8 +448,6 @@ golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
@@ -499,9 +465,8 @@ golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@@ -516,8 +481,6 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||
@@ -528,9 +491,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
|
||||
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
|
||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||
@@ -541,8 +503,6 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||
google.golang.org/protobuf v1.36.8 h1:xHScyCOEuuwZEc6UtSOvPbAT4zRh0xcNRYekJwfqyMc=
|
||||
google.golang.org/protobuf v1.36.8/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -561,7 +521,6 @@ gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
|
||||
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=
|
||||
gorm.io/driver/sqlserver v1.6.3 h1:UR+nWCuphPnq7UxnL57PSrlYjuvs+sf1N59GgFX7uAI=
|
||||
gorm.io/driver/sqlserver v1.6.3/go.mod h1:VZeNn7hqX1aXoN5TPAFGWvxWG90xtA8erGn2gQmpc6U=
|
||||
gorm.io/gorm v1.30.0 h1:qbT5aPv1UH8gI99OsRlvDToLxW5zR7FzS9acZDOZcgs=
|
||||
gorm.io/gorm v1.30.0/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
@@ -579,8 +538,6 @@ modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||
modernc.org/libc v1.67.0 h1:QzL4IrKab2OFmxA3/vRYl0tLXrIamwrhD6CKD4WBVjQ=
|
||||
modernc.org/libc v1.67.0/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
||||
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||
@@ -591,8 +548,6 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||
modernc.org/sqlite v1.40.1 h1:VfuXcxcUWWKRBuP8+BR9L7VnmusMgBNNnBYGEe9w/iY=
|
||||
modernc.org/sqlite v1.40.1/go.mod h1:9fjQZ0mB1LLP0GYrp39oOJXx/I2sxEnZtzCmEQIKvGE=
|
||||
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
|
||||
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
|
||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||
|
||||
362
openapi.yaml
362
openapi.yaml
@@ -1,362 +0,0 @@
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: ResolveSpec API
|
||||
version: '1.0'
|
||||
description: A flexible REST API with GraphQL-like capabilities
|
||||
|
||||
servers:
|
||||
- url: 'http://api.example.com/v1'
|
||||
|
||||
paths:
|
||||
'/{schema}/{entity}':
|
||||
parameters:
|
||||
- name: schema
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: entity
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
get:
|
||||
summary: Get table metadata
|
||||
description: Retrieve table metadata including columns, types, and relationships
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/Response'
|
||||
- type: object
|
||||
properties:
|
||||
data:
|
||||
$ref: '#/components/schemas/TableMetadata'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
post:
|
||||
summary: Perform operations on entities
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Request'
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
|
||||
'/{schema}/{entity}/{id}':
|
||||
parameters:
|
||||
- name: schema
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: entity
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: id
|
||||
in: path
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
post:
|
||||
summary: Perform operations on a specific entity
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Request'
|
||||
responses:
|
||||
'200':
|
||||
description: Successful operation
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest'
|
||||
'404':
|
||||
$ref: '#/components/responses/NotFound'
|
||||
'500':
|
||||
$ref: '#/components/responses/ServerError'
|
||||
|
||||
components:
|
||||
schemas:
|
||||
Request:
|
||||
type: object
|
||||
required:
|
||||
- operation
|
||||
properties:
|
||||
operation:
|
||||
type: string
|
||||
enum:
|
||||
- read
|
||||
- create
|
||||
- update
|
||||
- delete
|
||||
id:
|
||||
oneOf:
|
||||
- type: string
|
||||
- type: array
|
||||
items:
|
||||
type: string
|
||||
description: Optional record identifier(s) when not provided in URL
|
||||
data:
|
||||
oneOf:
|
||||
- type: object
|
||||
- type: array
|
||||
items:
|
||||
type: object
|
||||
description: Data for single or bulk create/update operations
|
||||
options:
|
||||
$ref: '#/components/schemas/Options'
|
||||
|
||||
Options:
|
||||
type: object
|
||||
properties:
|
||||
preload:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/PreloadOption'
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
filters:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/FilterOption'
|
||||
sort:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/SortOption'
|
||||
limit:
|
||||
type: integer
|
||||
minimum: 0
|
||||
offset:
|
||||
type: integer
|
||||
minimum: 0
|
||||
customOperators:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/CustomOperator'
|
||||
computedColumns:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/ComputedColumn'
|
||||
|
||||
PreloadOption:
|
||||
type: object
|
||||
properties:
|
||||
relation:
|
||||
type: string
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
filters:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/FilterOption'
|
||||
|
||||
FilterOption:
|
||||
type: object
|
||||
required:
|
||||
- column
|
||||
- operator
|
||||
- value
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
operator:
|
||||
type: string
|
||||
enum:
|
||||
- eq
|
||||
- neq
|
||||
- gt
|
||||
- gte
|
||||
- lt
|
||||
- lte
|
||||
- like
|
||||
- ilike
|
||||
- in
|
||||
value:
|
||||
type: object
|
||||
|
||||
SortOption:
|
||||
type: object
|
||||
required:
|
||||
- column
|
||||
- direction
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
direction:
|
||||
type: string
|
||||
enum:
|
||||
- asc
|
||||
- desc
|
||||
|
||||
CustomOperator:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- sql
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
sql:
|
||||
type: string
|
||||
|
||||
ComputedColumn:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- expression
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
expression:
|
||||
type: string
|
||||
|
||||
Response:
|
||||
type: object
|
||||
required:
|
||||
- success
|
||||
properties:
|
||||
success:
|
||||
type: boolean
|
||||
data:
|
||||
type: object
|
||||
metadata:
|
||||
$ref: '#/components/schemas/Metadata'
|
||||
error:
|
||||
$ref: '#/components/schemas/Error'
|
||||
|
||||
Metadata:
|
||||
type: object
|
||||
properties:
|
||||
total:
|
||||
type: integer
|
||||
filtered:
|
||||
type: integer
|
||||
limit:
|
||||
type: integer
|
||||
offset:
|
||||
type: integer
|
||||
|
||||
Error:
|
||||
type: object
|
||||
properties:
|
||||
code:
|
||||
type: string
|
||||
message:
|
||||
type: string
|
||||
details:
|
||||
type: object
|
||||
|
||||
TableMetadata:
|
||||
type: object
|
||||
required:
|
||||
- schema
|
||||
- table
|
||||
- columns
|
||||
- relations
|
||||
properties:
|
||||
schema:
|
||||
type: string
|
||||
description: Schema name
|
||||
table:
|
||||
type: string
|
||||
description: Table name
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/Column'
|
||||
relations:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: List of relation names
|
||||
|
||||
Column:
|
||||
type: object
|
||||
required:
|
||||
- name
|
||||
- type
|
||||
- is_nullable
|
||||
- is_primary
|
||||
- is_unique
|
||||
- has_index
|
||||
properties:
|
||||
name:
|
||||
type: string
|
||||
description: Column name
|
||||
type:
|
||||
type: string
|
||||
description: Data type of the column
|
||||
is_nullable:
|
||||
type: boolean
|
||||
description: Whether the column can contain null values
|
||||
is_primary:
|
||||
type: boolean
|
||||
description: Whether the column is a primary key
|
||||
is_unique:
|
||||
type: boolean
|
||||
description: Whether the column has a unique constraint
|
||||
has_index:
|
||||
type: boolean
|
||||
description: Whether the column is indexed
|
||||
|
||||
responses:
|
||||
BadRequest:
|
||||
description: Bad request
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
NotFound:
|
||||
description: Resource not found
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
ServerError:
|
||||
description: Internal server error
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/Response'
|
||||
|
||||
securitySchemes:
|
||||
bearerAuth:
|
||||
type: http
|
||||
scheme: bearer
|
||||
bearerFormat: JWT
|
||||
|
||||
security:
|
||||
- bearerAuth: []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,12 +15,16 @@ import (
|
||||
|
||||
// GormAdapter adapts GORM to work with our Database interface
|
||||
type GormAdapter struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
// NewGormAdapter creates a new GORM adapter
|
||||
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||
return &GormAdapter{db: db}
|
||||
adapter := &GormAdapter{db: db}
|
||||
// Initialize driver name
|
||||
adapter.driverName = adapter.DriverName()
|
||||
return adapter
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||
@@ -40,7 +44,7 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||
return &GormSelectQuery{db: g.db}
|
||||
return &GormSelectQuery{db: g.db, driverName: g.driverName}
|
||||
}
|
||||
|
||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||
@@ -79,7 +83,7 @@ func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if tx.Error != nil {
|
||||
return nil, tx.Error
|
||||
}
|
||||
return &GormAdapter{db: tx}, nil
|
||||
return &GormAdapter{db: tx, driverName: g.driverName}, nil
|
||||
}
|
||||
|
||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -97,7 +101,7 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
||||
}
|
||||
}()
|
||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
adapter := &GormAdapter{db: tx}
|
||||
adapter := &GormAdapter{db: tx, driverName: g.driverName}
|
||||
return fn(adapter)
|
||||
})
|
||||
}
|
||||
@@ -106,12 +110,30 @@ func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||
return g.db
|
||||
}
|
||||
|
||||
func (g *GormAdapter) DriverName() string {
|
||||
if g.db.Dialector == nil {
|
||||
return ""
|
||||
}
|
||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||
switch name := g.db.Name(); name {
|
||||
case "sqlserver":
|
||||
return "mssql"
|
||||
case "sqlite3":
|
||||
return "sqlite"
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
// GormSelectQuery implements SelectQuery for GORM
|
||||
type GormSelectQuery struct {
|
||||
db *gorm.DB
|
||||
schema string // Separated schema name
|
||||
tableName string // Just the table name, without schema
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
inJoinContext bool // Track if we're in a JOIN relation context
|
||||
joinTableAlias string // Alias to use for JOIN conditions
|
||||
}
|
||||
@@ -123,7 +145,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
fullTableName := provider.TableName()
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(fullTableName)
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(fullTableName, g.driverName)
|
||||
}
|
||||
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
@@ -136,7 +159,8 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||
g.db = g.db.Table(table)
|
||||
// Check if the table name contains schema (e.g., "schema.table")
|
||||
g.schema, g.tableName = parseTableName(table)
|
||||
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||
|
||||
return g
|
||||
}
|
||||
@@ -322,7 +346,8 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
||||
}
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
}
|
||||
|
||||
current := common.SelectQuery(wrapper)
|
||||
@@ -360,6 +385,7 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
||||
|
||||
wrapper := &GormSelectQuery{
|
||||
db: db,
|
||||
driverName: g.driverName,
|
||||
inJoinContext: true, // Mark as JOIN context
|
||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||
}
|
||||
|
||||
@@ -16,12 +16,19 @@ import (
|
||||
// PgSQLAdapter adapts standard database/sql to work with our Database interface
|
||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||
type PgSQLAdapter struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
driverName string
|
||||
}
|
||||
|
||||
// NewPgSQLAdapter creates a new PostgreSQL adapter
|
||||
func NewPgSQLAdapter(db *sql.DB) *PgSQLAdapter {
|
||||
return &PgSQLAdapter{db: db}
|
||||
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
||||
// An optional driverName (e.g. "postgres", "sqlite", "mssql") can be provided;
|
||||
// it defaults to "postgres" when omitted.
|
||||
func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
||||
name := "postgres"
|
||||
if len(driverName) > 0 && driverName[0] != "" {
|
||||
name = driverName[0]
|
||||
}
|
||||
return &PgSQLAdapter{db: db, driverName: name}
|
||||
}
|
||||
|
||||
// EnableQueryDebug enables query debugging for development
|
||||
@@ -31,22 +38,25 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
||||
|
||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
db: p.db,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
db: p.db,
|
||||
values: make(map[string]interface{}),
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -56,6 +66,7 @@ func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
db: p.db,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
}
|
||||
@@ -98,7 +109,7 @@ func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &PgSQLTxAdapter{tx: tx}, nil
|
||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||
@@ -121,7 +132,7 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
||||
return err
|
||||
}
|
||||
|
||||
adapter := &PgSQLTxAdapter{tx: tx}
|
||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
@@ -141,6 +152,10 @@ func (p *PgSQLAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.db
|
||||
}
|
||||
|
||||
func (p *PgSQLAdapter) DriverName() string {
|
||||
return p.driverName
|
||||
}
|
||||
|
||||
// preloadConfig represents a relationship to be preloaded
|
||||
type preloadConfig struct {
|
||||
relation string
|
||||
@@ -165,6 +180,7 @@ type PgSQLSelectQuery struct {
|
||||
model interface{}
|
||||
tableName string
|
||||
tableAlias string
|
||||
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||
columns []string
|
||||
columnExprs []string
|
||||
whereClauses []string
|
||||
@@ -183,7 +199,9 @@ type PgSQLSelectQuery struct {
|
||||
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||
p.tableAlias = provider.TableAlias()
|
||||
@@ -192,7 +210,8 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -375,12 +394,12 @@ func (p *PgSQLSelectQuery) buildSQL() string {
|
||||
|
||||
// LIMIT clause
|
||||
if p.limit > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
|
||||
fmt.Fprintf(&sb, " LIMIT %d", p.limit)
|
||||
}
|
||||
|
||||
// OFFSET clause
|
||||
if p.offset > 0 {
|
||||
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
|
||||
fmt.Fprintf(&sb, " OFFSET %d", p.offset)
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
@@ -501,16 +520,19 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
||||
|
||||
// PgSQLInsertQuery implements InsertQuery for PostgreSQL
|
||||
type PgSQLInsertQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
values map[string]interface{}
|
||||
returning []string
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
// Extract values from model using reflection
|
||||
// This is a simplified implementation
|
||||
@@ -518,7 +540,8 @@ func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||
}
|
||||
|
||||
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -591,6 +614,7 @@ type PgSQLUpdateQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
model interface{}
|
||||
sets map[string]interface{}
|
||||
whereClauses []string
|
||||
@@ -602,13 +626,16 @@ type PgSQLUpdateQuery struct {
|
||||
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||
p.model = model
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
if p.model == nil {
|
||||
model, err := modelregistry.GetModelByName(table)
|
||||
if err == nil {
|
||||
@@ -749,6 +776,7 @@ type PgSQLDeleteQuery struct {
|
||||
db *sql.DB
|
||||
tx *sql.Tx
|
||||
tableName string
|
||||
driverName string
|
||||
whereClauses []string
|
||||
args []interface{}
|
||||
paramCounter int
|
||||
@@ -756,13 +784,16 @@ type PgSQLDeleteQuery struct {
|
||||
|
||||
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
p.tableName = provider.TableName()
|
||||
fullTableName := provider.TableName()
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(fullTableName, p.driverName)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
||||
p.tableName = table
|
||||
// For SQLite, convert "schema.table" to "schema_table"
|
||||
_, p.tableName = parseTableName(table, p.driverName)
|
||||
return p
|
||||
}
|
||||
|
||||
@@ -835,27 +866,31 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
|
||||
|
||||
// PgSQLTxAdapter wraps a PostgreSQL transaction
|
||||
type PgSQLTxAdapter struct {
|
||||
tx *sql.Tx
|
||||
tx *sql.Tx
|
||||
driverName string
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
||||
return &PgSQLSelectQuery{
|
||||
tx: p.tx,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
columns: []string{"*"},
|
||||
args: make([]interface{}, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
||||
return &PgSQLInsertQuery{
|
||||
tx: p.tx,
|
||||
values: make(map[string]interface{}),
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
values: make(map[string]interface{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
return &PgSQLUpdateQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
sets: make(map[string]interface{}),
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
@@ -865,6 +900,7 @@ func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
||||
return &PgSQLDeleteQuery{
|
||||
tx: p.tx,
|
||||
driverName: p.driverName,
|
||||
args: make([]interface{}, 0),
|
||||
whereClauses: make([]string, 0),
|
||||
}
|
||||
@@ -912,6 +948,10 @@ func (p *PgSQLTxAdapter) GetUnderlyingDB() interface{} {
|
||||
return p.tx
|
||||
}
|
||||
|
||||
func (p *PgSQLTxAdapter) DriverName() string {
|
||||
return p.driverName
|
||||
}
|
||||
|
||||
// applyJoinPreloads adds JOINs for relationships that should use JOIN strategy
|
||||
func (p *PgSQLSelectQuery) applyJoinPreloads() {
|
||||
for _, preload := range p.preloads {
|
||||
@@ -1036,9 +1076,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
||||
// Create a new select query for the related table
|
||||
var db common.Database
|
||||
if p.tx != nil {
|
||||
db = &PgSQLTxAdapter{tx: p.tx}
|
||||
db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
|
||||
} else {
|
||||
db = &PgSQLAdapter{db: p.db}
|
||||
db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
|
||||
}
|
||||
|
||||
query := db.NewSelect().
|
||||
|
||||
@@ -11,15 +11,71 @@ import (
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// PostgreSQL identifier length limit (63 bytes + null terminator = 64 bytes total)
|
||||
const postgresIdentifierLimit = 63
|
||||
|
||||
// checkAliasLength checks if a preload relation path will generate aliases that exceed PostgreSQL's limit
|
||||
// Returns true if the alias is likely to be truncated
|
||||
func checkAliasLength(relation string) bool {
|
||||
// Bun generates aliases like: parentalias__childalias__columnname
|
||||
// For nested preloads, it uses the pattern: relation1__relation2__relation3__columnname
|
||||
parts := strings.Split(relation, ".")
|
||||
if len(parts) <= 1 {
|
||||
return false // Single level relations are fine
|
||||
}
|
||||
|
||||
// Calculate the actual alias prefix length that Bun will generate
|
||||
// Bun uses double underscores (__) between each relation level
|
||||
// and converts the relation names to lowercase with underscores
|
||||
aliasPrefix := strings.ToLower(strings.Join(parts, "__"))
|
||||
aliasPrefixLen := len(aliasPrefix)
|
||||
|
||||
// We need to add 2 more underscores for the column name separator plus column name length
|
||||
// Column names in the error were things like "rid_mastertype_hubtype" (23 chars)
|
||||
// To be safe, assume the longest column name could be around 35 chars
|
||||
maxColumnNameLen := 35
|
||||
estimatedMaxLen := aliasPrefixLen + 2 + maxColumnNameLen
|
||||
|
||||
// Check if this would exceed PostgreSQL's identifier limit
|
||||
if estimatedMaxLen > postgresIdentifierLimit {
|
||||
logger.Warn("Preload relation '%s' will generate aliases up to %d chars (prefix: %d + column: %d), exceeding PostgreSQL's %d char limit",
|
||||
relation, estimatedMaxLen, aliasPrefixLen, maxColumnNameLen, postgresIdentifierLimit)
|
||||
return true
|
||||
}
|
||||
|
||||
// Also check if just the prefix is getting close (within 15 chars of limit)
|
||||
// This gives room for column names
|
||||
if aliasPrefixLen > (postgresIdentifierLimit - 15) {
|
||||
logger.Warn("Preload relation '%s' has alias prefix of %d chars, which may cause truncation with longer column names (limit: %d)",
|
||||
relation, aliasPrefixLen, postgresIdentifierLimit)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
// For example: "public.users" -> ("public", "users")
|
||||
//
|
||||
// "users" -> ("", "users")
|
||||
func parseTableName(fullTableName string) (schema, table string) {
|
||||
//
|
||||
// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
|
||||
// in the same way as PostgreSQL/MSSQL
|
||||
func parseTableName(fullTableName, driverName string) (schema, table string) {
|
||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||
return fullTableName[:idx], fullTableName[idx+1:]
|
||||
schema = fullTableName[:idx]
|
||||
table = fullTableName[idx+1:]
|
||||
|
||||
// For SQLite, convert schema.table to schema_table
|
||||
if driverName == "sqlite" || driverName == "sqlite3" {
|
||||
table = schema + "_" + table
|
||||
schema = ""
|
||||
}
|
||||
return schema, table
|
||||
}
|
||||
return "", fullTableName
|
||||
}
|
||||
|
||||
@@ -26,10 +26,13 @@ func DefaultCORSConfig() CORSConfig {
|
||||
|
||||
for i := range cfg.Servers.Instances {
|
||||
server := cfg.Servers.Instances[i]
|
||||
if server.Port == 0 {
|
||||
continue
|
||||
}
|
||||
hosts = append(hosts, server.ExternalURLs...)
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", server.Host, server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("https://%s:%d", server.Host, server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", "localhost", server.Port))
|
||||
hosts = append(hosts, server.ExternalURLs...)
|
||||
for _, ip := range ipsList {
|
||||
hosts = append(hosts, fmt.Sprintf("http://%s:%d", ip.String(), server.Port))
|
||||
hosts = append(hosts, fmt.Sprintf("https://%s:%d", ip.String(), server.Port))
|
||||
@@ -111,11 +114,14 @@ func GetHeadSpecHeaders() []string {
|
||||
}
|
||||
|
||||
// SetCORSHeaders sets CORS headers on a response writer
|
||||
func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
func SetCORSHeaders(w ResponseWriter, r Request, config CORSConfig) {
|
||||
// Set allowed origins
|
||||
if len(config.AllowedOrigins) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||
}
|
||||
// if len(config.AllowedOrigins) > 0 {
|
||||
// w.SetHeader("Access-Control-Allow-Origin", strings.Join(config.AllowedOrigins, ", "))
|
||||
// }
|
||||
|
||||
// Todo origin list parsing
|
||||
w.SetHeader("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Set allowed methods
|
||||
if len(config.AllowedMethods) > 0 {
|
||||
@@ -123,9 +129,10 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
}
|
||||
|
||||
// Set allowed headers
|
||||
if len(config.AllowedHeaders) > 0 {
|
||||
w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||
}
|
||||
// if len(config.AllowedHeaders) > 0 {
|
||||
// w.SetHeader("Access-Control-Allow-Headers", strings.Join(config.AllowedHeaders, ", "))
|
||||
// }
|
||||
w.SetHeader("Access-Control-Allow-Headers", "*")
|
||||
|
||||
// Set max age
|
||||
if config.MaxAge > 0 {
|
||||
@@ -136,5 +143,7 @@ func SetCORSHeaders(w ResponseWriter, config CORSConfig) {
|
||||
w.SetHeader("Access-Control-Allow-Credentials", "true")
|
||||
|
||||
// Expose headers that clients can read
|
||||
w.SetHeader("Access-Control-Expose-Headers", "Content-Range, X-Api-Range-Total, X-Api-Range-Size")
|
||||
exposeHeaders := config.AllowedHeaders
|
||||
exposeHeaders = append(exposeHeaders, "Content-Range", "X-Api-Range-Total", "X-Api-Range-Size")
|
||||
w.SetHeader("Access-Control-Expose-Headers", strings.Join(exposeHeaders, ", "))
|
||||
}
|
||||
|
||||
@@ -30,6 +30,12 @@ type Database interface {
|
||||
// For Bun, this returns *bun.DB
|
||||
// This is useful for provider-specific features like PostgreSQL NOTIFY/LISTEN
|
||||
GetUnderlyingDB() interface{}
|
||||
|
||||
// DriverName returns the canonical name of the underlying database driver.
|
||||
// Possible values: "postgres", "sqlite", "mssql", "mysql".
|
||||
// All adapters normalise vendor-specific strings (e.g. Bun's "pg", GORM's
|
||||
// "sqlserver") to the values above before returning.
|
||||
DriverName() string
|
||||
}
|
||||
|
||||
// SelectQuery interface for building SELECT queries (compatible with both GORM and Bun)
|
||||
|
||||
@@ -74,6 +74,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
logger.Error("Invalid model type: operation=%s, table=%s, modelType=%v, expected struct", operation, tableName, modelType)
|
||||
return nil, fmt.Errorf("model must be a struct type, got %v", modelType)
|
||||
}
|
||||
|
||||
@@ -97,50 +98,74 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
}
|
||||
}
|
||||
|
||||
// Filter regularData to only include fields that exist in the model
|
||||
// Use MapToStruct to validate and filter fields
|
||||
regularData = p.filterValidFields(regularData, model)
|
||||
|
||||
// Inject parent IDs for foreign key resolution
|
||||
p.injectForeignKeys(regularData, modelType, parentIDs)
|
||||
|
||||
// Get the primary key name for this model
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Check if we have any data to process (besides _request)
|
||||
hasData := len(regularData) > 0
|
||||
|
||||
// Process based on operation
|
||||
switch strings.ToLower(operation) {
|
||||
case "insert", "create":
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
result.ID = id
|
||||
result.AffectedRows = 1
|
||||
result.Data = regularData
|
||||
// Only perform insert if we have data to insert
|
||||
if hasData {
|
||||
id, err := p.processInsert(ctx, regularData, tableName)
|
||||
if err != nil {
|
||||
logger.Error("Insert failed for table=%s, data=%+v, error=%v", tableName, regularData, err)
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
result.ID = id
|
||||
result.AffectedRows = 1
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations after parent insert (to get parent ID)
|
||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
// Process child relations after parent insert (to get parent ID)
|
||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||
}
|
||||
|
||||
case "update":
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
// Only perform update if we have data to update
|
||||
if hasData {
|
||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||
if err != nil {
|
||||
logger.Error("Update failed for table=%s, id=%v, data=%+v, error=%v", tableName, data[pkName], regularData, err)
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
result.AffectedRows = rows
|
||||
result.Data = regularData
|
||||
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
// Process child relations for update
|
||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||
}
|
||||
} else {
|
||||
logger.Debug("Skipping update for %s - no data columns besides _request", tableName)
|
||||
result.ID = data[pkName]
|
||||
}
|
||||
|
||||
case "delete":
|
||||
// Process child relations first (for referential integrity)
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType); err != nil {
|
||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
||||
}
|
||||
|
||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||
if err != nil {
|
||||
logger.Error("Delete failed for table=%s, id=%v, error=%v", tableName, data[pkName], err)
|
||||
return nil, fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
result.ID = data[pkName]
|
||||
@@ -148,6 +173,7 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
||||
result.Data = regularData
|
||||
|
||||
default:
|
||||
logger.Error("Unsupported operation: %s for table=%s", operation, tableName)
|
||||
return nil, fmt.Errorf("unsupported operation: %s", operation)
|
||||
}
|
||||
|
||||
@@ -165,6 +191,115 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
||||
return ""
|
||||
}
|
||||
|
||||
// filterValidFields filters input data to only include fields that exist in the model
|
||||
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
||||
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||
if len(data) == 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model to use with MapToStruct
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return data
|
||||
}
|
||||
|
||||
// Create a new instance of the model
|
||||
tempModel := reflect.New(modelType).Interface()
|
||||
|
||||
// Use MapToStruct to map the data - this will only map valid fields
|
||||
err := reflection.MapToStruct(data, tempModel)
|
||||
if err != nil {
|
||||
logger.Debug("Error mapping data to model: %v", err)
|
||||
return data
|
||||
}
|
||||
|
||||
// Extract the mapped fields back into a map
|
||||
// This effectively filters out any fields that don't exist in the model
|
||||
filteredData := make(map[string]interface{})
|
||||
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
||||
|
||||
for key, value := range data {
|
||||
// Check if the field was successfully mapped
|
||||
if fieldWasMapped(tempModelValue, modelType, key) {
|
||||
filteredData[key] = value
|
||||
} else {
|
||||
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredData
|
||||
}
|
||||
|
||||
// fieldWasMapped checks if a field with the given key was mapped to the model
|
||||
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
||||
// Look for the field by JSON tag or field name
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check bun tag
|
||||
bunTag := field.Tag.Get("bun")
|
||||
if bunTag != "" && bunTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check gorm tag
|
||||
gormTag := field.Tag.Get("gorm")
|
||||
if gormTag != "" && gormTag != "-" {
|
||||
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check lowercase field name
|
||||
if strings.EqualFold(field.Name, key) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle embedded structs recursively
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
embeddedValue := modelValue.Field(i)
|
||||
if embeddedValue.Kind() == reflect.Ptr {
|
||||
if embeddedValue.IsNil() {
|
||||
continue
|
||||
}
|
||||
embeddedValue = embeddedValue.Elem()
|
||||
}
|
||||
if fieldWasMapped(embeddedValue, fieldType, key) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||
if len(parentIDs) == 0 {
|
||||
@@ -213,6 +348,7 @@ func (p *NestedCUDProcessor) processInsert(
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
|
||||
return nil, fmt.Errorf("insert exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -236,6 +372,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
id interface{},
|
||||
) (int64, error) {
|
||||
if id == nil {
|
||||
logger.Error("Update requires an ID: table=%s, data=%+v", tableName, data)
|
||||
return 0, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
@@ -245,6 +382,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Update execution failed: table=%s, id=%v, data=%+v, error=%v", tableName, id, data, err)
|
||||
return 0, fmt.Errorf("update exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -256,6 +394,7 @@ func (p *NestedCUDProcessor) processUpdate(
|
||||
// processDelete handles delete operation
|
||||
func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string, id interface{}) (int64, error) {
|
||||
if id == nil {
|
||||
logger.Error("Delete requires an ID: table=%s", tableName)
|
||||
return 0, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
@@ -265,6 +404,7 @@ func (p *NestedCUDProcessor) processDelete(ctx context.Context, tableName string
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
logger.Error("Delete execution failed: table=%s, id=%v, error=%v", tableName, id, err)
|
||||
return 0, fmt.Errorf("delete exec failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -281,6 +421,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
relationFields map[string]*RelationshipInfo,
|
||||
relationData map[string]interface{},
|
||||
parentModelType reflect.Type,
|
||||
incomingParentIDs map[string]interface{}, // IDs from all ancestors
|
||||
) error {
|
||||
for relationName, relInfo := range relationFields {
|
||||
relationValue, exists := relationData[relationName]
|
||||
@@ -293,7 +434,7 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
// Get the related model
|
||||
field, found := parentModelType.FieldByName(relInfo.FieldName)
|
||||
if !found {
|
||||
logger.Warn("Field %s not found in model", relInfo.FieldName)
|
||||
logger.Error("Field %s not found in model type %v for relation %s", relInfo.FieldName, parentModelType, relationName)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -313,20 +454,89 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
relatedTableName := p.getTableNameForModel(relatedModel, relInfo.JSONName)
|
||||
|
||||
// Prepare parent IDs for foreign key injection
|
||||
// Start by copying all incoming parent IDs (from ancestors)
|
||||
parentIDs := make(map[string]interface{})
|
||||
if relInfo.ForeignKey != "" {
|
||||
for k, v := range incomingParentIDs {
|
||||
parentIDs[k] = v
|
||||
}
|
||||
logger.Debug("Inherited %d parent IDs from ancestors: %+v", len(incomingParentIDs), incomingParentIDs)
|
||||
|
||||
// Add the current parent's primary key to the parentIDs map
|
||||
// This ensures nested children have access to all ancestor IDs
|
||||
if parentID != nil && parentModelType != nil {
|
||||
// Get the parent model's primary key field name
|
||||
parentPKFieldName := reflection.GetPrimaryKeyName(parentModelType)
|
||||
if parentPKFieldName != "" {
|
||||
// Get the JSON name for the primary key field
|
||||
parentPKJSONName := reflection.GetJSONNameForField(parentModelType, parentPKFieldName)
|
||||
baseName := ""
|
||||
if len(parentPKJSONName) > 1 {
|
||||
baseName = parentPKJSONName
|
||||
} else {
|
||||
// Add parent's PK to the map using the base model name
|
||||
baseName = strings.TrimSuffix(parentPKFieldName, "ID")
|
||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||
if baseName == "" {
|
||||
baseName = "parent"
|
||||
}
|
||||
}
|
||||
|
||||
parentIDs[baseName] = parentID
|
||||
logger.Debug("Added current parent PK to parentIDs map: %s=%v (from field %s)", baseName, parentID, parentPKFieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Also add the foreign key reference if specified
|
||||
if relInfo.ForeignKey != "" && parentID != nil {
|
||||
// Extract the base name from foreign key (e.g., "DepartmentID" -> "Department")
|
||||
baseName := strings.TrimSuffix(relInfo.ForeignKey, "ID")
|
||||
baseName = strings.TrimSuffix(strings.ToLower(baseName), "_id")
|
||||
parentIDs[baseName] = parentID
|
||||
// Only add if different from what we already added
|
||||
if _, exists := parentIDs[baseName]; !exists {
|
||||
parentIDs[baseName] = parentID
|
||||
logger.Debug("Added foreign key to parentIDs map: %s=%v (from FK %s)", baseName, parentID, relInfo.ForeignKey)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Final parentIDs map for relation %s: %+v", relationName, parentIDs)
|
||||
|
||||
// Determine which field name to use for setting parent ID in child data
|
||||
// Priority: Use foreign key field name if specified
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
}
|
||||
logger.Debug("Using foreign key field for direct assignment: %s (from FK %s)", foreignKeyFieldName, relInfo.ForeignKey)
|
||||
}
|
||||
|
||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||
if childPKFieldName == "" {
|
||||
childPKFieldName = strings.ToLower(childPKName)
|
||||
}
|
||||
|
||||
logger.Debug("Processing relation with foreignKeyField=%s, childPK=%s", foreignKeyFieldName, childPKFieldName)
|
||||
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object
|
||||
// Single related object - directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
v[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process single relation: name=%s, table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, relatedTableName, operation, parentID, v, err)
|
||||
return fmt.Errorf("failed to process relation %s: %w", relationName, err)
|
||||
}
|
||||
|
||||
@@ -334,24 +544,46 @@ func (p *NestedCUDProcessor) processChildRelations(
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
// Directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process relation array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Relation array item is not a map: name=%s[%d], type=%T", relationName, i, item)
|
||||
}
|
||||
}
|
||||
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
// Directly set foreign key if specified
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := p.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
logger.Error("Failed to process relation typed array item: name=%s[%d], table=%s, operation=%s, parentID=%v, data=%+v, error=%v",
|
||||
relationName, i, relatedTableName, operation, parentID, itemMap, err)
|
||||
return fmt.Errorf("failed to process relation %s[%d]: %w", relationName, i, err)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
logger.Warn("Unsupported relation data type for %s: %T", relationName, relationValue)
|
||||
logger.Error("Unsupported relation data type: name=%s, type=%T, value=%+v", relationName, relationValue, relationValue)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
723
pkg/common/recursive_crud_test.go
Normal file
723
pkg/common/recursive_crud_test.go
Normal file
@@ -0,0 +1,723 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// Mock Database for testing
|
||||
type mockDatabase struct {
|
||||
insertCalls []map[string]interface{}
|
||||
updateCalls []map[string]interface{}
|
||||
deleteCalls []interface{}
|
||||
lastID int64
|
||||
}
|
||||
|
||||
func newMockDatabase() *mockDatabase {
|
||||
return &mockDatabase{
|
||||
insertCalls: make([]map[string]interface{}, 0),
|
||||
updateCalls: make([]map[string]interface{}, 0),
|
||||
deleteCalls: make([]interface{}, 0),
|
||||
lastID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDatabase) NewSelect() SelectQuery { return &mockSelectQuery{} }
|
||||
func (m *mockDatabase) NewInsert() InsertQuery { return &mockInsertQuery{db: m} }
|
||||
func (m *mockDatabase) NewUpdate() UpdateQuery { return &mockUpdateQuery{db: m} }
|
||||
func (m *mockDatabase) NewDelete() DeleteQuery { return &mockDeleteQuery{db: m} }
|
||||
func (m *mockDatabase) RunInTransaction(ctx context.Context, fn func(Database) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
func (m *mockDatabase) Exec(ctx context.Context, query string, args ...interface{}) (Result, error) {
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
func (m *mockDatabase) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) BeginTx(ctx context.Context) (Database, error) {
|
||||
return m, nil
|
||||
}
|
||||
func (m *mockDatabase) CommitTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) RollbackTx(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) GetUnderlyingDB() interface{} {
|
||||
return nil
|
||||
}
|
||||
func (m *mockDatabase) DriverName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
// Mock SelectQuery
|
||||
type mockSelectQuery struct{}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Table(name string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Column(columns ...string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Where(condition string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Join(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) LeftJoin(query string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(SelectQuery) SelectQuery) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Order(order string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Limit(n int) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Offset(n int) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Group(group string) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Having(condition string, args ...interface{}) SelectQuery { return m }
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error { return nil }
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error { return nil }
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) { return 0, nil }
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) { return false, nil }
|
||||
|
||||
// Mock InsertQuery
|
||||
type mockInsertQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
values map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockInsertQuery) Model(model interface{}) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Table(name string) InsertQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery {
|
||||
if m.values == nil {
|
||||
m.values = make(map[string]interface{})
|
||||
}
|
||||
m.values[column] = value
|
||||
return m
|
||||
}
|
||||
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
|
||||
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the insert call
|
||||
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
||||
m.db.lastID++
|
||||
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock UpdateQuery
|
||||
type mockUpdateQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
setValues map[string]interface{}
|
||||
}
|
||||
|
||||
func (m *mockUpdateQuery) Model(model interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Table(name string) UpdateQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockUpdateQuery) Set(column string, value interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) SetMap(values map[string]interface{}) UpdateQuery {
|
||||
m.setValues = values
|
||||
return m
|
||||
}
|
||||
func (m *mockUpdateQuery) Where(condition string, args ...interface{}) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Returning(columns ...string) UpdateQuery { return m }
|
||||
func (m *mockUpdateQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the update call
|
||||
m.db.updateCalls = append(m.db.updateCalls, m.setValues)
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock DeleteQuery
|
||||
type mockDeleteQuery struct {
|
||||
db *mockDatabase
|
||||
table string
|
||||
}
|
||||
|
||||
func (m *mockDeleteQuery) Model(model interface{}) DeleteQuery { return m }
|
||||
func (m *mockDeleteQuery) Table(name string) DeleteQuery {
|
||||
m.table = name
|
||||
return m
|
||||
}
|
||||
func (m *mockDeleteQuery) Where(condition string, args ...interface{}) DeleteQuery { return m }
|
||||
func (m *mockDeleteQuery) Exec(ctx context.Context) (Result, error) {
|
||||
// Record the delete call
|
||||
m.db.deleteCalls = append(m.db.deleteCalls, m.table)
|
||||
return &mockResult{rowsAffected: 1}, nil
|
||||
}
|
||||
|
||||
// Mock Result
|
||||
type mockResult struct {
|
||||
lastID int64
|
||||
rowsAffected int64
|
||||
}
|
||||
|
||||
func (m *mockResult) LastInsertId() (int64, error) { return m.lastID, nil }
|
||||
func (m *mockResult) RowsAffected() int64 { return m.rowsAffected }
|
||||
|
||||
// Mock ModelRegistry
|
||||
type mockModelRegistry struct{}
|
||||
|
||||
func (m *mockModelRegistry) GetModel(name string) (interface{}, error) { return nil, nil }
|
||||
func (m *mockModelRegistry) GetModelByEntity(schema, entity string) (interface{}, error) { return nil, nil }
|
||||
func (m *mockModelRegistry) RegisterModel(name string, model interface{}) error { return nil }
|
||||
func (m *mockModelRegistry) GetAllModels() map[string]interface{} { return make(map[string]interface{}) }
|
||||
|
||||
// Mock RelationshipInfoProvider
|
||||
type mockRelationshipProvider struct {
|
||||
relationships map[string]*RelationshipInfo
|
||||
}
|
||||
|
||||
func newMockRelationshipProvider() *mockRelationshipProvider {
|
||||
return &mockRelationshipProvider{
|
||||
relationships: make(map[string]*RelationshipInfo),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockRelationshipProvider) GetRelationshipInfo(modelType reflect.Type, relationName string) *RelationshipInfo {
|
||||
key := modelType.Name() + "." + relationName
|
||||
return m.relationships[key]
|
||||
}
|
||||
|
||||
func (m *mockRelationshipProvider) RegisterRelation(modelTypeName, relationName string, info *RelationshipInfo) {
|
||||
key := modelTypeName + "." + relationName
|
||||
m.relationships[key] = info
|
||||
}
|
||||
|
||||
// Test Models
|
||||
type Department struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name"`
|
||||
Employees []*Employee `json:"employees,omitempty"`
|
||||
}
|
||||
|
||||
func (d Department) TableName() string { return "departments" }
|
||||
func (d Department) GetIDName() string { return "ID" }
|
||||
|
||||
type Employee struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Name string `json:"name"`
|
||||
DepartmentID int64 `json:"department_id"`
|
||||
Tasks []*Task `json:"tasks,omitempty"`
|
||||
}
|
||||
|
||||
func (e Employee) TableName() string { return "employees" }
|
||||
func (e Employee) GetIDName() string { return "ID" }
|
||||
|
||||
type Task struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Title string `json:"title"`
|
||||
EmployeeID int64 `json:"employee_id"`
|
||||
Comments []*Comment `json:"comments,omitempty"`
|
||||
}
|
||||
|
||||
func (t Task) TableName() string { return "tasks" }
|
||||
func (t Task) GetIDName() string { return "ID" }
|
||||
|
||||
type Comment struct {
|
||||
ID int64 `json:"id" bun:"id,pk"`
|
||||
Text string `json:"text"`
|
||||
TaskID int64 `json:"task_id"`
|
||||
}
|
||||
|
||||
func (c Comment) TableName() string { return "comments" }
|
||||
func (c Comment) GetIDName() string { return "ID" }
|
||||
|
||||
// Test Cases
|
||||
|
||||
func TestProcessNestedCUD_SingleLevelInsert(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register Department -> Employees relationship
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"name": "Jane Smith",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID == nil {
|
||||
t.Error("Expected result.ID to be set")
|
||||
}
|
||||
|
||||
// Verify department was inserted
|
||||
if len(db.insertCalls) != 3 {
|
||||
t.Errorf("Expected 3 insert calls (1 dept + 2 employees), got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify first insert is department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employees were inserted with foreign key
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
if db.insertCalls[2]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_MultiLevelInsert(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register relationships
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||
FieldName: "Tasks",
|
||||
JSONName: "tasks",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "EmployeeID",
|
||||
RelatedModel: Task{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
"tasks": []interface{}{
|
||||
map[string]interface{}{
|
||||
"title": "Task 1",
|
||||
},
|
||||
map[string]interface{}{
|
||||
"title": "Task 2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID == nil {
|
||||
t.Error("Expected result.ID to be set")
|
||||
}
|
||||
|
||||
// Verify: 1 dept + 1 employee + 2 tasks = 4 inserts
|
||||
if len(db.insertCalls) != 4 {
|
||||
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employee has department_id
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id set")
|
||||
}
|
||||
|
||||
// Verify tasks have employee_id
|
||||
if db.insertCalls[2]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id set")
|
||||
}
|
||||
if db.insertCalls[3]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_RequestFieldOverride(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "update",
|
||||
"ID": int64(10), // Use capital ID to match struct field
|
||||
"name": "John Updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify department was inserted (1 insert)
|
||||
// Employee should be updated (1 update)
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert call for department, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
if len(db.updateCalls) != 1 {
|
||||
t.Errorf("Expected 1 update call for employee, got %d", len(db.updateCalls))
|
||||
}
|
||||
|
||||
// Verify update data
|
||||
if db.updateCalls[0]["name"] != "John Updated" {
|
||||
t.Errorf("Expected employee name 'John Updated', got %v", db.updateCalls[0]["name"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_SkipInsertWhenOnlyRequestField(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
// Data with only _request field for nested employee
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "insert",
|
||||
// No other fields besides _request
|
||||
// Note: Foreign key will be injected, so employee WILL be inserted
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Department + Employee (with injected FK) = 2 inserts
|
||||
if len(db.insertCalls) != 2 {
|
||||
t.Errorf("Expected 2 insert calls (department + employee with FK), got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Errorf("Expected department name 'Engineering', got %v", db.insertCalls[0]["name"])
|
||||
}
|
||||
|
||||
// Verify employee has foreign key
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id injected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_Update(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ID": int64(1), // Use capital ID to match struct field
|
||||
"name": "Engineering Updated",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "insert",
|
||||
"name": "New Employee",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"update",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
if result.ID != int64(1) {
|
||||
t.Errorf("Expected result.ID to be 1, got %v", result.ID)
|
||||
}
|
||||
|
||||
// Verify department was updated
|
||||
if len(db.updateCalls) != 1 {
|
||||
t.Errorf("Expected 1 update call, got %d", len(db.updateCalls))
|
||||
}
|
||||
|
||||
// Verify new employee was inserted
|
||||
if len(db.insertCalls) != 1 {
|
||||
t.Errorf("Expected 1 insert call for new employee, got %d", len(db.insertCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_Delete(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"ID": int64(1), // Use capital ID to match struct field
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"_request": "delete",
|
||||
"ID": int64(10), // Use capital ID
|
||||
},
|
||||
map[string]interface{}{
|
||||
"_request": "delete",
|
||||
"ID": int64(11), // Use capital ID
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"delete",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify employees were deleted first, then department
|
||||
// 2 employees + 1 department = 3 deletes
|
||||
if len(db.deleteCalls) != 3 {
|
||||
t.Errorf("Expected 3 delete calls, got %d", len(db.deleteCalls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessNestedCUD_ParentIDPropagation(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
// Register 3-level relationships
|
||||
relProvider.RegisterRelation("Department", "employees", &RelationshipInfo{
|
||||
FieldName: "Employees",
|
||||
JSONName: "employees",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "DepartmentID",
|
||||
RelatedModel: Employee{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Employee", "tasks", &RelationshipInfo{
|
||||
FieldName: "Tasks",
|
||||
JSONName: "tasks",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "EmployeeID",
|
||||
RelatedModel: Task{},
|
||||
})
|
||||
|
||||
relProvider.RegisterRelation("Task", "comments", &RelationshipInfo{
|
||||
FieldName: "Comments",
|
||||
JSONName: "comments",
|
||||
RelationType: "has_many",
|
||||
ForeignKey: "TaskID",
|
||||
RelatedModel: Comment{},
|
||||
})
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "Engineering",
|
||||
"employees": []interface{}{
|
||||
map[string]interface{}{
|
||||
"name": "John",
|
||||
"tasks": []interface{}{
|
||||
map[string]interface{}{
|
||||
"title": "Task 1",
|
||||
"comments": []interface{}{
|
||||
map[string]interface{}{
|
||||
"text": "Great work!",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := processor.ProcessNestedCUD(
|
||||
context.Background(),
|
||||
"insert",
|
||||
data,
|
||||
Department{},
|
||||
nil,
|
||||
"departments",
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("ProcessNestedCUD failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify: 1 dept + 1 employee + 1 task + 1 comment = 4 inserts
|
||||
if len(db.insertCalls) != 4 {
|
||||
t.Errorf("Expected 4 insert calls, got %d", len(db.insertCalls))
|
||||
}
|
||||
|
||||
// Verify department
|
||||
if db.insertCalls[0]["name"] != "Engineering" {
|
||||
t.Error("Expected department to be inserted first")
|
||||
}
|
||||
|
||||
// Verify employee has department_id
|
||||
if db.insertCalls[1]["department_id"] == nil {
|
||||
t.Error("Expected employee to have department_id")
|
||||
}
|
||||
|
||||
// Verify task has employee_id
|
||||
if db.insertCalls[2]["employee_id"] == nil {
|
||||
t.Error("Expected task to have employee_id")
|
||||
}
|
||||
|
||||
// Verify comment has task_id
|
||||
if db.insertCalls[3]["task_id"] == nil {
|
||||
t.Error("Expected comment to have task_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectForeignKeys(t *testing.T) {
|
||||
db := newMockDatabase()
|
||||
registry := &mockModelRegistry{}
|
||||
relProvider := newMockRelationshipProvider()
|
||||
|
||||
processor := NewNestedCUDProcessor(db, registry, relProvider)
|
||||
|
||||
data := map[string]interface{}{
|
||||
"name": "John",
|
||||
}
|
||||
|
||||
parentIDs := map[string]interface{}{
|
||||
"department": int64(5),
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(Employee{})
|
||||
|
||||
processor.injectForeignKeys(data, modelType, parentIDs)
|
||||
|
||||
// Should inject department_id based on the "department" key in parentIDs
|
||||
if data["department_id"] == nil {
|
||||
t.Error("Expected department_id to be injected")
|
||||
}
|
||||
|
||||
if data["department_id"] != int64(5) {
|
||||
t.Errorf("Expected department_id to be 5, got %v", data["department_id"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPrimaryKeyName(t *testing.T) {
|
||||
dept := Department{}
|
||||
pkName := reflection.GetPrimaryKeyName(dept)
|
||||
|
||||
if pkName != "ID" {
|
||||
t.Errorf("Expected primary key name 'ID', got '%s'", pkName)
|
||||
}
|
||||
|
||||
// Test with pointer
|
||||
pkName2 := reflection.GetPrimaryKeyName(&dept)
|
||||
if pkName2 != "ID" {
|
||||
t.Errorf("Expected primary key name 'ID' from pointer, got '%s'", pkName2)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
@@ -130,6 +131,9 @@ func validateWhereClauseSecurity(where string) error {
|
||||
// Note: This function will NOT add prefixes to unprefixed columns. It will only fix
|
||||
// incorrect prefixes (e.g., wrong_table.column -> correct_table.column), unless the
|
||||
// prefix matches a preloaded relation name, in which case it's left unchanged.
|
||||
//
|
||||
// IMPORTANT: Outer parentheses are preserved if the clause contains top-level OR operators
|
||||
// to prevent OR logic from escaping and affecting the entire query incorrectly.
|
||||
func SanitizeWhereClause(where string, tableName string, options ...*RequestOptions) string {
|
||||
if where == "" {
|
||||
return ""
|
||||
@@ -143,8 +147,19 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
return ""
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim
|
||||
where = stripOuterParentheses(where)
|
||||
// Check if the original clause has outer parentheses and contains OR operators
|
||||
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
|
||||
hasOuterParens := false
|
||||
if len(where) > 0 && where[0] == '(' && where[len(where)-1] == ')' {
|
||||
_, hasOuterParens = stripOneMatchingOuterParen(where)
|
||||
}
|
||||
|
||||
// Strip outer parentheses and re-trim for processing
|
||||
whereWithoutParens := stripOuterParentheses(where)
|
||||
shouldPreserveParens := hasOuterParens && containsTopLevelOR(whereWithoutParens)
|
||||
|
||||
// Use the stripped version for processing
|
||||
where = whereWithoutParens
|
||||
|
||||
// Get valid columns from the model if tableName is provided
|
||||
var validColumns map[string]bool
|
||||
@@ -153,19 +168,28 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
}
|
||||
|
||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||
// Keys are stored lowercase for case-insensitive matching
|
||||
allowedPrefixes := make(map[string]bool)
|
||||
if tableName != "" {
|
||||
allowedPrefixes[tableName] = true
|
||||
allowedPrefixes[strings.ToLower(tableName)] = true
|
||||
}
|
||||
|
||||
// Add preload relation names as allowed prefixes
|
||||
if len(options) > 0 && options[0] != nil {
|
||||
for pi := range options[0].Preload {
|
||||
if options[0].Preload[pi].Relation != "" {
|
||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
||||
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
|
||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||
}
|
||||
}
|
||||
|
||||
// Add join aliases as allowed prefixes
|
||||
for _, alias := range options[0].JoinAliases {
|
||||
if alias != "" {
|
||||
allowedPrefixes[strings.ToLower(alias)] = true
|
||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Split by AND to handle multiple conditions
|
||||
@@ -194,8 +218,8 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||
|
||||
if currentPrefix != "" && columnName != "" {
|
||||
// Check if the prefix is allowed (main table or preload relation)
|
||||
if !allowedPrefixes[currentPrefix] {
|
||||
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
|
||||
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
|
||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||
// Replace the incorrect prefix with the correct main table name
|
||||
@@ -221,7 +245,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
||||
|
||||
result := strings.Join(validConditions, " AND ")
|
||||
|
||||
if result != where {
|
||||
// If the original clause had outer parentheses and contains OR operators,
|
||||
// restore the outer parentheses to prevent OR logic from escaping
|
||||
if shouldPreserveParens {
|
||||
result = "(" + result + ")"
|
||||
logger.Debug("Preserved outer parentheses for OR conditions: '%s'", result)
|
||||
}
|
||||
|
||||
if result != where && !shouldPreserveParens {
|
||||
logger.Debug("Sanitized WHERE clause: '%s' -> '%s'", where, result)
|
||||
}
|
||||
|
||||
@@ -282,6 +313,93 @@ func stripOneMatchingOuterParen(s string) (string, bool) {
|
||||
return strings.TrimSpace(s[1 : len(s)-1]), true
|
||||
}
|
||||
|
||||
// EnsureOuterParentheses ensures that a SQL clause is wrapped in parentheses
|
||||
// to prevent OR logic from escaping. It checks if the clause already has
|
||||
// matching outer parentheses and only adds them if they don't exist.
|
||||
//
|
||||
// This is particularly important for OR conditions and complex filters where
|
||||
// the absence of parentheses could cause the logic to escape and affect
|
||||
// the entire query incorrectly.
|
||||
//
|
||||
// Parameters:
|
||||
// - clause: The SQL clause to check and potentially wrap
|
||||
//
|
||||
// Returns:
|
||||
// - The clause with guaranteed outer parentheses, or empty string if input is empty
|
||||
func EnsureOuterParentheses(clause string) string {
|
||||
if clause == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
clause = strings.TrimSpace(clause)
|
||||
if clause == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if the clause already has matching outer parentheses
|
||||
_, hasOuterParens := stripOneMatchingOuterParen(clause)
|
||||
|
||||
// If it already has matching outer parentheses, return as-is
|
||||
if hasOuterParens {
|
||||
return clause
|
||||
}
|
||||
|
||||
// Otherwise, wrap it in parentheses
|
||||
return "(" + clause + ")"
|
||||
}
|
||||
|
||||
// containsTopLevelOR checks if a SQL clause contains OR operators at the top level
|
||||
// (i.e., not inside parentheses or subqueries). This is used to determine if
|
||||
// outer parentheses should be preserved to prevent OR logic from escaping.
|
||||
func containsTopLevelOR(clause string) bool {
|
||||
if clause == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
depth := 0
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
lowerClause := strings.ToLower(clause)
|
||||
|
||||
for i := 0; i < len(clause); i++ {
|
||||
ch := clause[i]
|
||||
|
||||
// Track quote state
|
||||
if ch == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if ch == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if inside quotes
|
||||
if inSingleQuote || inDoubleQuote {
|
||||
continue
|
||||
}
|
||||
|
||||
// Track parenthesis depth
|
||||
switch ch {
|
||||
case '(':
|
||||
depth++
|
||||
case ')':
|
||||
depth--
|
||||
}
|
||||
|
||||
// Only check for OR at depth 0 (not inside parentheses)
|
||||
if depth == 0 && i+4 <= len(clause) {
|
||||
// Check for " OR " (case-insensitive)
|
||||
substring := lowerClause[i : i+4]
|
||||
if substring == " or " {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// splitByAND splits a WHERE clause by AND operators (case-insensitive)
|
||||
// This is parenthesis-aware and won't split on AND operators inside subqueries
|
||||
func splitByAND(where string) []string {
|
||||
@@ -809,3 +927,36 @@ func extractLeftSideOfComparison(cond string) string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
|
||||
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
|
||||
// Returns a single-element slice if the value is not a slice type.
|
||||
func FilterValueToSlice(v interface{}) []interface{} {
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
if rv.Kind() == reflect.Slice {
|
||||
result := make([]interface{}, rv.Len())
|
||||
for i := 0; i < rv.Len(); i++ {
|
||||
result[i] = rv.Index(i).Interface()
|
||||
}
|
||||
return result
|
||||
}
|
||||
return []interface{}{v}
|
||||
}
|
||||
|
||||
// BuildInCondition builds a parameterized IN condition from a filter value.
|
||||
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
|
||||
// Returns ("", nil) if the value is empty or not a slice.
|
||||
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
|
||||
values := FilterValueToSlice(v)
|
||||
if len(values) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
placeholders := make([]string, len(values))
|
||||
for i := range values {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
|
||||
}
|
||||
|
||||
103
pkg/common/sql_helpers_tablename_test.go
Normal file
103
pkg/common/sql_helpers_tablename_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestSanitizeWhereClause_WithTableName tests that table prefixes in WHERE clauses
|
||||
// are correctly handled when the tableName parameter matches the prefix
|
||||
func TestSanitizeWhereClause_WithTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
options *RequestOptions
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Correct table prefix should not be changed",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Wrong table prefix should be fixed",
|
||||
where: "wrong_table.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Relation name should not replace correct table prefix",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: &RequestOptions{
|
||||
Preload: []PreloadOption{
|
||||
{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "mastertaskitem",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Unqualified column should remain unqualified",
|
||||
where: "rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
options: nil,
|
||||
expected: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeWhereClause(tt.where, tt.tableName, tt.options)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q, want %q",
|
||||
tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestAddTablePrefixToColumns_WithTableName tests that table prefixes
|
||||
// are correctly added to unqualified columns
|
||||
func TestAddTablePrefixToColumns_WithTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Add prefix to unqualified column",
|
||||
where: "rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Don't change already qualified column",
|
||||
where: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "mastertaskitem.rid_parentmastertaskitem is null",
|
||||
},
|
||||
{
|
||||
name: "Don't change qualified column with different table",
|
||||
where: "other_table.rid_something is null",
|
||||
tableName: "mastertaskitem",
|
||||
expected: "other_table.rid_something is null",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AddTablePrefixToColumns(%q, %q) = %q, want %q",
|
||||
tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -659,6 +659,179 @@ func TestSanitizeWhereClauseWithModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureOuterParentheses(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "no parentheses",
|
||||
input: "status = 'active'",
|
||||
expected: "(status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "already has outer parentheses",
|
||||
input: "(status = 'active')",
|
||||
expected: "(status = 'active')",
|
||||
},
|
||||
{
|
||||
name: "OR condition without parentheses",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: "(status = 'active' OR status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "OR condition with parentheses",
|
||||
input: "(status = 'active' OR status = 'pending')",
|
||||
expected: "(status = 'active' OR status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "complex condition with nested parentheses",
|
||||
input: "(status = 'active' OR status = 'pending') AND (age > 18)",
|
||||
expected: "((status = 'active' OR status = 'pending') AND (age > 18))",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "mismatched parentheses - adds outer ones",
|
||||
input: "(status = 'active' OR status = 'pending'",
|
||||
expected: "((status = 'active' OR status = 'pending')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := EnsureOuterParentheses(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("EnsureOuterParentheses(%q) = %q; want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsTopLevelOR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no OR operator",
|
||||
input: "status = 'active' AND age > 18",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "top-level OR",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OR inside parentheses",
|
||||
input: "age > 18 AND (status = 'active' OR status = 'pending')",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "OR in subquery",
|
||||
input: "id IN (SELECT id FROM users WHERE status = 'active' OR status = 'pending')",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "OR inside quotes",
|
||||
input: "comment = 'this OR that'",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "mixed - top-level OR and nested OR",
|
||||
input: "name = 'test' OR (status = 'active' OR status = 'pending')",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "lowercase or",
|
||||
input: "status = 'active' or status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "uppercase OR",
|
||||
input: "status = 'active' OR status = 'pending'",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := containsTopLevelOR(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("containsTopLevelOR(%q) = %v; want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeWhereClause_PreservesParenthesesWithOR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
tableName string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "OR condition with outer parentheses - preserved",
|
||||
where: "(status = 'active' OR status = 'pending')",
|
||||
tableName: "users",
|
||||
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||
},
|
||||
{
|
||||
name: "AND condition with outer parentheses - stripped (no OR)",
|
||||
where: "(status = 'active' AND age > 18)",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "complex OR with nested conditions",
|
||||
where: "((status = 'active' OR status = 'pending') AND age > 18)",
|
||||
tableName: "users",
|
||||
// Outer parens are stripped, but inner parens with OR are preserved
|
||||
expected: "(users.status = 'active' OR users.status = 'pending') AND users.age > 18",
|
||||
},
|
||||
{
|
||||
name: "OR without outer parentheses - no parentheses added by SanitizeWhereClause",
|
||||
where: "status = 'active' OR status = 'pending'",
|
||||
tableName: "users",
|
||||
expected: "users.status = 'active' OR users.status = 'pending'",
|
||||
},
|
||||
{
|
||||
name: "simple OR with parentheses - preserved",
|
||||
where: "(users.status = 'active' OR users.status = 'pending')",
|
||||
tableName: "users",
|
||||
// Already has correct prefixes, parentheses preserved
|
||||
expected: "(users.status = 'active' OR users.status = 'pending')",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
prefixedWhere := AddTablePrefixToColumns(tt.where, tt.tableName)
|
||||
result := SanitizeWhereClause(prefixedWhere, tt.tableName)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SanitizeWhereClause(%q, %q) = %q; want %q", tt.where, tt.tableName, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddTablePrefixToColumns_ComplexConditions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -23,6 +23,10 @@ type RequestOptions struct {
|
||||
CursorForward string `json:"cursor_forward"`
|
||||
CursorBackward string `json:"cursor_backward"`
|
||||
FetchRowNumber *string `json:"fetch_row_number"`
|
||||
|
||||
// Join table aliases (used for validation of prefixed columns in filters/sorts)
|
||||
// Not serialized to JSON as it's internal validation state
|
||||
JoinAliases []string `json:"-"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
@@ -33,6 +37,7 @@ type Parameter struct {
|
||||
|
||||
type PreloadOption struct {
|
||||
Relation string `json:"relation"`
|
||||
TableName string `json:"table_name"` // Actual database table name (e.g., "mastertaskitem")
|
||||
Columns []string `json:"columns"`
|
||||
OmitColumns []string `json:"omit_columns"`
|
||||
Sort []SortOption `json:"sort"`
|
||||
@@ -45,9 +50,14 @@ type PreloadOption struct {
|
||||
Recursive bool `json:"recursive"` // if true, preload recursively up to 5 levels
|
||||
|
||||
// Relationship keys from XFiles - used to build proper foreign key filters
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
PrimaryKey string `json:"primary_key"` // Primary key of the related table
|
||||
RelatedKey string `json:"related_key"` // For child tables: column in child that references parent
|
||||
ForeignKey string `json:"foreign_key"` // For parent tables: column in current table that references parent
|
||||
RecursiveChildKey string `json:"recursive_child_key"` // For recursive tables: FK column used for recursion (e.g., "rid_parentmastertaskitem")
|
||||
|
||||
// Custom SQL JOINs from XFiles - used when preload needs additional joins
|
||||
SqlJoins []string `json:"sql_joins"` // Custom SQL JOIN clauses
|
||||
JoinAliases []string `json:"join_aliases"` // Extracted table aliases from SqlJoins for validation
|
||||
}
|
||||
|
||||
type FilterOption struct {
|
||||
|
||||
@@ -237,15 +237,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
for _, sort := range options.Sort {
|
||||
if v.IsValidColumn(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
foundJoin := false
|
||||
for _, j := range options.JoinAliases {
|
||||
if strings.Contains(sort.Column, j) {
|
||||
foundJoin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundJoin {
|
||||
validSorts = append(validSorts, sort)
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||
// Allow sort by expression/subquery, but validate for security
|
||||
if IsSafeSortExpression(sort.Column) {
|
||||
validSorts = append(validSorts, sort)
|
||||
} else {
|
||||
logger.Warn("Unsafe sort expression '%s' removed", sort.Column)
|
||||
}
|
||||
|
||||
} else {
|
||||
logger.Warn("Invalid column in sort '%s' removed", sort.Column)
|
||||
}
|
||||
}
|
||||
}
|
||||
filtered.Sort = validSorts
|
||||
@@ -258,13 +272,29 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
||||
|
||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||
filteredPreload.SqlJoins = preload.SqlJoins
|
||||
filteredPreload.JoinAliases = preload.JoinAliases
|
||||
|
||||
// Filter preload filters
|
||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||
for _, filter := range preload.Filters {
|
||||
if v.IsValidColumn(filter.Column) {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
// Check if the filter column references a joined table alias
|
||||
foundJoin := false
|
||||
for _, alias := range preload.JoinAliases {
|
||||
if strings.Contains(filter.Column, alias) {
|
||||
foundJoin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundJoin {
|
||||
validPreloadFilters = append(validPreloadFilters, filter)
|
||||
} else {
|
||||
logger.Warn("Invalid column in preload '%s' filter '%s' removed", preload.Relation, filter.Column)
|
||||
}
|
||||
}
|
||||
}
|
||||
filteredPreload.Filters = validPreloadFilters
|
||||
@@ -291,6 +321,9 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
||||
}
|
||||
filtered.Preload = validPreloads
|
||||
|
||||
// Clear JoinAliases - this is an internal validation field and should not be persisted
|
||||
filtered.JoinAliases = nil
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
|
||||
@@ -362,6 +362,29 @@ func TestFilterRequestOptions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterRequestOptions_ClearsJoinAliases(t *testing.T) {
|
||||
model := TestModel{}
|
||||
validator := NewColumnValidator(model)
|
||||
|
||||
options := RequestOptions{
|
||||
Columns: []string{"id", "name"},
|
||||
// Set JoinAliases - this should be cleared by FilterRequestOptions
|
||||
JoinAliases: []string{"d", "u", "r"},
|
||||
}
|
||||
|
||||
filtered := validator.FilterRequestOptions(options)
|
||||
|
||||
// Verify that JoinAliases was cleared (internal field should not persist)
|
||||
if filtered.JoinAliases != nil {
|
||||
t.Errorf("Expected JoinAliases to be nil after filtering, got %v", filtered.JoinAliases)
|
||||
}
|
||||
|
||||
// Verify that other fields are still properly filtered
|
||||
if len(filtered.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns, got %d", len(filtered.Columns))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeSortExpression(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
|
||||
- **GORM** - Popular Go ORM
|
||||
- **Native** - Standard library `*sql.DB`
|
||||
- All three share the same underlying connection pool
|
||||
- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
|
||||
- **Configuration-Driven**: YAML configuration with Viper integration
|
||||
- **Production-Ready Features**:
|
||||
- Automatic health checks and reconnection
|
||||
@@ -179,6 +180,35 @@ if err != nil {
|
||||
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||
```
|
||||
|
||||
#### Cross-Database Example with SQLite
|
||||
|
||||
```go
|
||||
// Same model works across all databases
|
||||
type User struct {
|
||||
ID int `bun:"id,pk"`
|
||||
Username string `bun:"username"`
|
||||
Email string `bun:"email"`
|
||||
}
|
||||
|
||||
func (User) TableName() string {
|
||||
return "auth.users"
|
||||
}
|
||||
|
||||
// PostgreSQL connection
|
||||
pgConn, _ := mgr.Get("primary")
|
||||
pgDB, _ := pgConn.Bun()
|
||||
var pgUsers []User
|
||||
pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
|
||||
// Executes: SELECT * FROM auth.users
|
||||
|
||||
// SQLite connection
|
||||
sqliteConn, _ := mgr.Get("cache-db")
|
||||
sqliteDB, _ := sqliteConn.Bun()
|
||||
var sqliteUsers []User
|
||||
sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
|
||||
// Executes: SELECT * FROM auth_users (schema.table → schema_table)
|
||||
```
|
||||
|
||||
#### Use MongoDB
|
||||
|
||||
```go
|
||||
@@ -368,6 +398,37 @@ Providers handle:
|
||||
- Connection statistics
|
||||
- Connection cleanup
|
||||
|
||||
### SQLite Schema Handling
|
||||
|
||||
SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
|
||||
|
||||
**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
|
||||
|
||||
```go
|
||||
// Model definition (works across all databases)
|
||||
func (User) TableName() string {
|
||||
return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
|
||||
// SQLite: "auth_users"
|
||||
}
|
||||
|
||||
// Query execution
|
||||
db.NewSelect().Model(&User{}).Scan(ctx)
|
||||
// PostgreSQL/MSSQL: SELECT * FROM auth.users
|
||||
// SQLite: SELECT * FROM auth_users
|
||||
```
|
||||
|
||||
**How it Works**:
|
||||
- Bun, GORM, and Native adapters detect the driver type
|
||||
- `parseTableName()` automatically translates schema.table → schema_table for SQLite
|
||||
- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
|
||||
- Preload and relation queries are also handled automatically
|
||||
|
||||
**Benefits**:
|
||||
- Write database-agnostic code
|
||||
- Use the same models across PostgreSQL, MSSQL, and SQLite
|
||||
- No conditional logic needed in your application
|
||||
- Schema separation maintained through naming convention in SQLite
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Use Named Connections**: Be explicit about which database you're accessing
|
||||
|
||||
@@ -128,7 +128,7 @@ func DefaultManagerConfig() ManagerConfig {
|
||||
RetryAttempts: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
RetryMaxDelay: 10 * time.Second,
|
||||
HealthCheckInterval: 30 * time.Second,
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
}
|
||||
@@ -161,6 +161,11 @@ func (c *ManagerConfig) ApplyDefaults() {
|
||||
if c.HealthCheckInterval == 0 {
|
||||
c.HealthCheckInterval = defaults.HealthCheckInterval
|
||||
}
|
||||
// EnableAutoReconnect defaults to true - apply if not explicitly set
|
||||
// Since this is a boolean, we apply the default unconditionally when it's false
|
||||
if !c.EnableAutoReconnect {
|
||||
c.EnableAutoReconnect = defaults.EnableAutoReconnect
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the manager configuration
|
||||
@@ -216,7 +221,10 @@ func (cc *ConnectionConfig) ApplyDefaults(global *ManagerConfig) {
|
||||
cc.ConnectTimeout = 10 * time.Second
|
||||
}
|
||||
if cc.QueryTimeout == 0 {
|
||||
cc.QueryTimeout = 30 * time.Second
|
||||
cc.QueryTimeout = 2 * time.Minute // Default to 2 minutes
|
||||
} else if cc.QueryTimeout < 2*time.Minute {
|
||||
// Enforce minimum of 2 minutes
|
||||
cc.QueryTimeout = 2 * time.Minute
|
||||
}
|
||||
|
||||
// Default ORM
|
||||
@@ -320,14 +328,29 @@ func (cc *ConnectionConfig) buildPostgresDSN() string {
|
||||
dsn += fmt.Sprintf(" search_path=%s", cc.Schema)
|
||||
}
|
||||
|
||||
// Add statement_timeout for query execution timeout (in milliseconds)
|
||||
if cc.QueryTimeout > 0 {
|
||||
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||
dsn += fmt.Sprintf(" statement_timeout=%d", timeoutMs)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildSQLiteDSN() string {
|
||||
if cc.FilePath != "" {
|
||||
return cc.FilePath
|
||||
filepath := cc.FilePath
|
||||
if filepath == "" {
|
||||
filepath = ":memory:"
|
||||
}
|
||||
return ":memory:"
|
||||
|
||||
// Add query parameters for timeouts
|
||||
// Note: SQLite driver supports _timeout parameter (in milliseconds)
|
||||
if cc.QueryTimeout > 0 {
|
||||
timeoutMs := int(cc.QueryTimeout.Milliseconds())
|
||||
filepath += fmt.Sprintf("?_timeout=%d", timeoutMs)
|
||||
}
|
||||
|
||||
return filepath
|
||||
}
|
||||
|
||||
func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||
@@ -339,6 +362,24 @@ func (cc *ConnectionConfig) buildMSSQLDSN() string {
|
||||
dsn += fmt.Sprintf("&schema=%s", cc.Schema)
|
||||
}
|
||||
|
||||
// Add connection timeout (in seconds)
|
||||
if cc.ConnectTimeout > 0 {
|
||||
timeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&connection timeout=%d", timeoutSec)
|
||||
}
|
||||
|
||||
// Add dial timeout for TCP connection (in seconds)
|
||||
if cc.ConnectTimeout > 0 {
|
||||
dialTimeoutSec := int(cc.ConnectTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&dial timeout=%d", dialTimeoutSec)
|
||||
}
|
||||
|
||||
// Add read timeout (in seconds) - enforces timeout for reading data
|
||||
if cc.QueryTimeout > 0 {
|
||||
readTimeoutSec := int(cc.QueryTimeout.Seconds())
|
||||
dsn += fmt.Sprintf("&read timeout=%d", readTimeoutSec)
|
||||
}
|
||||
|
||||
return dsn
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type Connection interface {
|
||||
Bun() (*bun.DB, error)
|
||||
GORM() (*gorm.DB, error)
|
||||
Native() (*sql.DB, error)
|
||||
DB() (*sql.DB, error)
|
||||
|
||||
// Common Database interface (for SQL databases)
|
||||
Database() (common.Database, error)
|
||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
||||
return c.nativeDB, nil
|
||||
}
|
||||
|
||||
// DB returns the underlying *sql.DB connection
|
||||
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||
return c.Native()
|
||||
}
|
||||
|
||||
// Bun returns a Bun ORM instance wrapping the native connection
|
||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||
if c == nil {
|
||||
@@ -467,13 +473,11 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
||||
// Create a native adapter based on database type
|
||||
switch c.dbType {
|
||||
case DatabaseTypePostgreSQL:
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
case DatabaseTypeSQLite:
|
||||
// For SQLite, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
case DatabaseTypeMSSQL:
|
||||
// For MSSQL, we'll use the PgSQL adapter as it works with standard sql.DB
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB)
|
||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
||||
default:
|
||||
return nil, ErrUnsupportedDatabase
|
||||
}
|
||||
@@ -647,6 +651,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// DB returns an error for MongoDB connections
|
||||
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
}
|
||||
|
||||
// Database returns an error for MongoDB connections
|
||||
func (c *mongoConnection) Database() (common.Database, error) {
|
||||
return nil, ErrNotSQLDatabase
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||
@@ -49,3 +50,18 @@ func createProvider(dbType DatabaseType) (Provider, error) {
|
||||
// Provider is an alias to the providers.Provider interface
|
||||
// This allows dbmanager package consumers to use Provider without importing providers
|
||||
type Provider = providers.Provider
|
||||
|
||||
// NewConnectionFromDB creates a new Connection from an existing *sql.DB
|
||||
// This allows you to use dbmanager features (ORM wrappers, health checks, etc.)
|
||||
// with a database connection that was opened outside of dbmanager
|
||||
//
|
||||
// Parameters:
|
||||
// - name: A unique name for this connection
|
||||
// - dbType: The database type (DatabaseTypePostgreSQL, DatabaseTypeSQLite, or DatabaseTypeMSSQL)
|
||||
// - db: An existing *sql.DB connection
|
||||
//
|
||||
// Returns a Connection that wraps the existing *sql.DB
|
||||
func NewConnectionFromDB(name string, dbType DatabaseType, db *sql.DB) Connection {
|
||||
provider := providers.NewExistingDBProvider(db, name)
|
||||
return newSQLConnection(name, dbType, ConnectionConfig{Name: name, Type: dbType}, provider)
|
||||
}
|
||||
|
||||
210
pkg/dbmanager/factory_test.go
Normal file
210
pkg/dbmanager/factory_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestNewConnectionFromDB(t *testing.T) {
|
||||
// Open a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create a connection from the existing database
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
if conn == nil {
|
||||
t.Fatal("Expected connection to be created")
|
||||
}
|
||||
|
||||
// Verify connection properties
|
||||
if conn.Name() != "test-connection" {
|
||||
t.Errorf("Expected name 'test-connection', got '%s'", conn.Name())
|
||||
}
|
||||
|
||||
if conn.Type() != DatabaseTypeSQLite {
|
||||
t.Errorf("Expected type DatabaseTypeSQLite, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Connect(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect should verify the existing connection works
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected Connect to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Native(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get native DB
|
||||
nativeDB, err := conn.Native()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Native to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if nativeDB != db {
|
||||
t.Error("Expected Native to return the same database instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Bun(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get Bun ORM
|
||||
bunDB, err := conn.Bun()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Bun to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if bunDB == nil {
|
||||
t.Error("Expected Bun to return a non-nil instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_GORM(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Get GORM
|
||||
gormDB, err := conn.GORM()
|
||||
if err != nil {
|
||||
t.Errorf("Expected GORM to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if gormDB == nil {
|
||||
t.Error("Expected GORM to return a non-nil instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_HealthCheck(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Health check should succeed
|
||||
err = conn.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_Stats(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-connection", DatabaseTypeSQLite, db)
|
||||
ctx := context.Background()
|
||||
|
||||
err = conn.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
stats := conn.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.Name != "test-connection" {
|
||||
t.Errorf("Expected stats.Name to be 'test-connection', got '%s'", stats.Name)
|
||||
}
|
||||
|
||||
if stats.Type != DatabaseTypeSQLite {
|
||||
t.Errorf("Expected stats.Type to be DatabaseTypeSQLite, got '%s'", stats.Type)
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected stats.Connected to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
||||
// This test just verifies the factory works with PostgreSQL type
|
||||
// It won't actually connect since we're using SQLite
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
conn := NewConnectionFromDB("test-pg", DatabaseTypePostgreSQL, db)
|
||||
if conn == nil {
|
||||
t.Fatal("Expected connection to be created")
|
||||
}
|
||||
|
||||
if conn.Type() != DatabaseTypePostgreSQL {
|
||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||
}
|
||||
}
|
||||
@@ -219,9 +219,10 @@ func (m *connectionManager) Connect(ctx context.Context) error {
|
||||
logger.Info("Database connection established: name=%s, type=%s", name, connCfg.Type)
|
||||
}
|
||||
|
||||
// Start background health checks if enabled
|
||||
if m.config.EnableAutoReconnect && m.config.HealthCheckInterval > 0 {
|
||||
// Always start background health checks
|
||||
if m.config.HealthCheckInterval > 0 {
|
||||
m.startHealthChecker()
|
||||
logger.Info("Background health checker started: interval=%v", m.config.HealthCheckInterval)
|
||||
}
|
||||
|
||||
logger.Info("Database manager initialized: connections=%d", len(m.connections))
|
||||
@@ -230,12 +231,14 @@ func (m *connectionManager) Connect(ctx context.Context) error {
|
||||
|
||||
// Close closes all database connections
|
||||
func (m *connectionManager) Close() error {
|
||||
// Stop the health checker before taking mu. performHealthCheck acquires
|
||||
// a read lock, so waiting for the goroutine while holding the write lock
|
||||
// would deadlock.
|
||||
m.stopHealthChecker()
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Stop health checker
|
||||
m.stopHealthChecker()
|
||||
|
||||
// Close all connections
|
||||
var errors []error
|
||||
for name, conn := range m.connections {
|
||||
|
||||
226
pkg/dbmanager/manager_test.go
Normal file
226
pkg/dbmanager/manager_test.go
Normal file
@@ -0,0 +1,226 @@
|
||||
package dbmanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestBackgroundHealthChecker(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create manager config with a short health check interval for testing
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 1 * time.Second, // Short interval for testing
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
|
||||
// Create manager
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// Connect - this should start the background health checker
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer mgr.Close()
|
||||
|
||||
// Get the connection to verify it's healthy
|
||||
conn, err := mgr.Get("test")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get connection: %v", err)
|
||||
}
|
||||
|
||||
// Verify initial health check
|
||||
err = conn.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Initial health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Wait for a few health check cycles
|
||||
time.Sleep(3 * time.Second)
|
||||
|
||||
// Get stats to verify the connection is still healthy
|
||||
stats := conn.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected connection to still be connected")
|
||||
}
|
||||
|
||||
if stats.HealthCheckStatus == "" {
|
||||
t.Error("Expected health check status to be set")
|
||||
}
|
||||
|
||||
// Verify the manager has started the health checker
|
||||
if cm, ok := mgr.(*connectionManager); ok {
|
||||
if cm.healthTicker == nil {
|
||||
t.Error("Expected health ticker to be running")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultHealthCheckInterval(t *testing.T) {
|
||||
// Verify the default health check interval is 15 seconds
|
||||
defaults := DefaultManagerConfig()
|
||||
|
||||
expectedInterval := 15 * time.Second
|
||||
if defaults.HealthCheckInterval != expectedInterval {
|
||||
t.Errorf("Expected default health check interval to be %v, got %v",
|
||||
expectedInterval, defaults.HealthCheckInterval)
|
||||
}
|
||||
|
||||
if !defaults.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be true by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyDefaultsEnablesAutoReconnect(t *testing.T) {
|
||||
// Create a config without setting EnableAutoReconnect
|
||||
cfg := ManagerConfig{
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Verify it's false initially (Go's zero value for bool)
|
||||
if cfg.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be false before ApplyDefaults")
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// Verify it's now true
|
||||
if !cfg.EnableAutoReconnect {
|
||||
t.Error("Expected EnableAutoReconnect to be true after ApplyDefaults")
|
||||
}
|
||||
|
||||
// Verify health check interval is also set
|
||||
if cfg.HealthCheckInterval != 15*time.Second {
|
||||
t.Errorf("Expected health check interval to be 15s, got %v", cfg.HealthCheckInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerHealthCheck(t *testing.T) {
|
||||
// Create a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create manager config
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
EnableAutoReconnect: true,
|
||||
}
|
||||
|
||||
// Create and connect manager
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer mgr.Close()
|
||||
|
||||
// Perform health check on all connections
|
||||
err = mgr.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Health check failed: %v", err)
|
||||
}
|
||||
|
||||
// Get stats
|
||||
stats := mgr.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.TotalConnections != 1 {
|
||||
t.Errorf("Expected 1 total connection, got %d", stats.TotalConnections)
|
||||
}
|
||||
|
||||
if stats.HealthyCount != 1 {
|
||||
t.Errorf("Expected 1 healthy connection, got %d", stats.HealthyCount)
|
||||
}
|
||||
|
||||
if stats.UnhealthyCount != 0 {
|
||||
t.Errorf("Expected 0 unhealthy connections, got %d", stats.UnhealthyCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerStatsAfterClose(t *testing.T) {
|
||||
cfg := ManagerConfig{
|
||||
DefaultConnection: "test",
|
||||
Connections: map[string]ConnectionConfig{
|
||||
"test": {
|
||||
Name: "test",
|
||||
Type: DatabaseTypeSQLite,
|
||||
FilePath: ":memory:",
|
||||
},
|
||||
},
|
||||
HealthCheckInterval: 15 * time.Second,
|
||||
}
|
||||
|
||||
mgr, err := NewManager(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err = mgr.Connect(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
|
||||
// Close the manager
|
||||
err = mgr.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Failed to close manager: %v", err)
|
||||
}
|
||||
|
||||
// Stats should show no connections
|
||||
stats := mgr.Stats()
|
||||
if stats.TotalConnections != 0 {
|
||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||
}
|
||||
}
|
||||
111
pkg/dbmanager/providers/existing_db.go
Normal file
111
pkg/dbmanager/providers/existing_db.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
// ExistingDBProvider wraps an existing *sql.DB connection
|
||||
// This allows using dbmanager features with a database connection
|
||||
// that was opened outside of the dbmanager package
|
||||
type ExistingDBProvider struct {
|
||||
db *sql.DB
|
||||
name string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewExistingDBProvider creates a new provider wrapping an existing *sql.DB
|
||||
func NewExistingDBProvider(db *sql.DB, name string) *ExistingDBProvider {
|
||||
return &ExistingDBProvider{
|
||||
db: db,
|
||||
name: name,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect verifies the existing database connection is valid
|
||||
// It does NOT create a new connection, but ensures the existing one works
|
||||
func (p *ExistingDBProvider) Connect(ctx context.Context, cfg ConnectionConfig) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
// Verify the connection works
|
||||
if err := p.db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("failed to ping existing database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection
|
||||
func (p *ExistingDBProvider) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.db == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
// HealthCheck verifies the connection is alive
|
||||
func (p *ExistingDBProvider) HealthCheck(ctx context.Context) error {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
return p.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
// GetNative returns the wrapped *sql.DB
|
||||
func (p *ExistingDBProvider) GetNative() (*sql.DB, error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
if p.db == nil {
|
||||
return nil, fmt.Errorf("database connection is nil")
|
||||
}
|
||||
|
||||
return p.db, nil
|
||||
}
|
||||
|
||||
// GetMongo returns an error since this is a SQL database
|
||||
func (p *ExistingDBProvider) GetMongo() (*mongo.Client, error) {
|
||||
return nil, ErrNotMongoDB
|
||||
}
|
||||
|
||||
// Stats returns connection statistics
|
||||
func (p *ExistingDBProvider) Stats() *ConnectionStats {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
stats := &ConnectionStats{
|
||||
Name: p.name,
|
||||
Type: "sql", // Generic since we don't know the specific type
|
||||
Connected: p.db != nil,
|
||||
}
|
||||
|
||||
if p.db != nil {
|
||||
dbStats := p.db.Stats()
|
||||
stats.OpenConnections = dbStats.OpenConnections
|
||||
stats.InUse = dbStats.InUse
|
||||
stats.Idle = dbStats.Idle
|
||||
stats.WaitCount = dbStats.WaitCount
|
||||
stats.WaitDuration = dbStats.WaitDuration
|
||||
stats.MaxIdleClosed = dbStats.MaxIdleClosed
|
||||
stats.MaxLifetimeClosed = dbStats.MaxLifetimeClosed
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
194
pkg/dbmanager/providers/existing_db_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestNewExistingDBProvider(t *testing.T) {
|
||||
// Open a SQLite in-memory database
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create provider
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
if provider == nil {
|
||||
t.Fatal("Expected provider to be created")
|
||||
}
|
||||
|
||||
if provider.name != "test-db" {
|
||||
t.Errorf("Expected name 'test-db', got '%s'", provider.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Connect(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
// Connect should verify the connection works
|
||||
err = provider.Connect(ctx, nil)
|
||||
if err != nil {
|
||||
t.Errorf("Expected Connect to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Connect_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
err := provider.Connect(ctx, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected Connect to fail with nil database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetNative(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
nativeDB, err := provider.GetNative()
|
||||
if err != nil {
|
||||
t.Errorf("Expected GetNative to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
if nativeDB != db {
|
||||
t.Error("Expected GetNative to return the same database instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetNative_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
|
||||
_, err := provider.GetNative()
|
||||
if err == nil {
|
||||
t.Error("Expected GetNative to fail with nil database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_HealthCheck(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
ctx := context.Background()
|
||||
|
||||
err = provider.HealthCheck(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Expected HealthCheck to succeed, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_HealthCheck_ClosedDB(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
// Close the database
|
||||
db.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
err = provider.HealthCheck(ctx)
|
||||
if err == nil {
|
||||
t.Error("Expected HealthCheck to fail with closed database")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_GetMongo(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
_, err = provider.GetMongo()
|
||||
if err != ErrNotMongoDB {
|
||||
t.Errorf("Expected ErrNotMongoDB, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Stats(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Set some connection pool settings to test stats
|
||||
db.SetMaxOpenConns(10)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(time.Hour)
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
stats := provider.Stats()
|
||||
if stats == nil {
|
||||
t.Fatal("Expected stats to be returned")
|
||||
}
|
||||
|
||||
if stats.Name != "test-db" {
|
||||
t.Errorf("Expected stats.Name to be 'test-db', got '%s'", stats.Name)
|
||||
}
|
||||
|
||||
if stats.Type != "sql" {
|
||||
t.Errorf("Expected stats.Type to be 'sql', got '%s'", stats.Type)
|
||||
}
|
||||
|
||||
if !stats.Connected {
|
||||
t.Error("Expected stats.Connected to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Close(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to open database: %v", err)
|
||||
}
|
||||
|
||||
provider := NewExistingDBProvider(db, "test-db")
|
||||
|
||||
err = provider.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Close to succeed, got error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the database is closed
|
||||
err = db.Ping()
|
||||
if err == nil {
|
||||
t.Error("Expected database to be closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingDBProvider_Close_NilDB(t *testing.T) {
|
||||
provider := NewExistingDBProvider(nil, "test-db")
|
||||
|
||||
err := provider.Close()
|
||||
if err != nil {
|
||||
t.Errorf("Expected Close to succeed with nil database, got error: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package providers_test
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Subscribe to a channel with a handler
|
||||
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
|
||||
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen: %v", err))
|
||||
log.Fatalf("Failed to listen: %v", err)
|
||||
}
|
||||
|
||||
// Send a notification
|
||||
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to notify: %v", err))
|
||||
log.Fatalf("Failed to notify: %v", err)
|
||||
}
|
||||
|
||||
// Wait for notification to be processed
|
||||
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
|
||||
|
||||
// Unsubscribe from the channel
|
||||
if err := listener.Unlisten("user_events"); err != nil {
|
||||
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
||||
log.Fatalf("Failed to unlisten: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// Listen to multiple channels
|
||||
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
|
||||
fmt.Printf("[%s] %s\n", ch, payload)
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
||||
log.Fatalf("Failed to listen on %s: %v", channel, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
|
||||
|
||||
provider := providers.NewPostgresProvider()
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
// Get listener
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Subscribe to application events
|
||||
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := provider.Connect(ctx, cfg); err != nil {
|
||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
||||
log.Fatalf("Failed to connect: %v", err)
|
||||
}
|
||||
defer provider.Close()
|
||||
|
||||
listener, err := provider.GetListener(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
||||
log.Fatalf("Failed to get listener: %v", err)
|
||||
}
|
||||
|
||||
// The listener automatically reconnects if the connection is lost
|
||||
|
||||
@@ -76,8 +76,12 @@ func (p *SQLiteProvider) Connect(ctx context.Context, cfg ConnectionConfig) erro
|
||||
// Don't fail connection if WAL mode cannot be enabled
|
||||
}
|
||||
|
||||
// Set busy timeout to handle locked database
|
||||
_, err = db.ExecContext(ctx, "PRAGMA busy_timeout=5000")
|
||||
// Set busy timeout to handle locked database (minimum 2 minutes = 120000ms)
|
||||
busyTimeout := cfg.GetQueryTimeout().Milliseconds()
|
||||
if busyTimeout < 120000 {
|
||||
busyTimeout = 120000 // Enforce minimum of 2 minutes
|
||||
}
|
||||
_, err = db.ExecContext(ctx, fmt.Sprintf("PRAGMA busy_timeout=%d", busyTimeout))
|
||||
if err != nil {
|
||||
if cfg.GetEnableLogging() {
|
||||
logger.Warn("Failed to set busy timeout for SQLite", "error", err)
|
||||
|
||||
@@ -74,6 +74,10 @@ func (m *MockDatabase) GetUnderlyingDB() interface{} {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *MockDatabase) DriverName() string {
|
||||
return "postgres"
|
||||
}
|
||||
|
||||
// MockResult implements common.Result interface for testing
|
||||
type MockResult struct {
|
||||
rows int64
|
||||
|
||||
@@ -2,14 +2,38 @@ package funcspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||
// We provide audit logging for data access tracking
|
||||
// We provide auth enforcement and audit logging for data access tracking
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeQueryList - Auth check before list query execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 0: BeforeQuery - Auth check before single query execution
|
||||
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = "authentication required"
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
|
||||
// ModelRules defines the permissions and security settings for a model
|
||||
type ModelRules struct {
|
||||
CanPublicRead bool // Whether the model can be read (GET operations)
|
||||
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanPublicCreate bool // Whether the model can be created (POST operations)
|
||||
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
|
||||
CanRead bool // Whether the model can be read (GET operations)
|
||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||
CanCreate bool // Whether the model can be created (POST operations)
|
||||
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
|
||||
CanUpdate: true,
|
||||
CanCreate: true,
|
||||
CanDelete: true,
|
||||
CanPublicRead: false,
|
||||
CanPublicUpdate: false,
|
||||
CanPublicCreate: false,
|
||||
CanPublicDelete: false,
|
||||
SecurityDisabled: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
|
||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||
- **Database Agnostic**: GORM and Bun ORM support
|
||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
||||
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
|
||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||
- **Thread-safe**: Proper concurrency handling throughout
|
||||
|
||||
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
|
||||
|
||||
### Hook Types
|
||||
|
||||
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||
- `BeforeRead` / `AfterRead` - Read operations
|
||||
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||
|
||||
### Security Hooks (Recommended)
|
||||
|
||||
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
mqttspec.RegisterSecurityHooks(handler, securityList)
|
||||
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||
```
|
||||
|
||||
### Authentication Example (JWT)
|
||||
|
||||
```go
|
||||
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
|
||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
||||
| **Hooks** | Same 13 hooks | Same 13 hooks |
|
||||
| **CRUD Operations** | Identical | Identical |
|
||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||
|
||||
|
||||
@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
|
||||
},
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
hookCtx.Operation = string(msg.Operation)
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
if hookCtx.Abort {
|
||||
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Route to operation handler
|
||||
switch msg.Operation {
|
||||
case OperationRead:
|
||||
@@ -645,11 +654,14 @@ func (h *Handler) getNotifyTopic(clientID, subscriptionID string) string {
|
||||
// Database operation helpers (adapted from websocketspec)
|
||||
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
// Use entity as table name
|
||||
tableName := entity
|
||||
|
||||
if schema != "" {
|
||||
tableName = schema + "." + tableName
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
tableName = schema + "_" + tableName
|
||||
} else {
|
||||
tableName = schema + "." + tableName
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
@@ -20,8 +20,11 @@ type (
|
||||
HookRegistry = websocketspec.HookRegistry
|
||||
)
|
||||
|
||||
// Hook type constants - all 12 lifecycle hooks
|
||||
// Hook type constants - all lifecycle hooks
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch
|
||||
BeforeHandle = websocketspec.BeforeHandle
|
||||
|
||||
// CRUD operation hooks
|
||||
BeforeRead = websocketspec.BeforeRead
|
||||
AfterRead = websocketspec.AfterRead
|
||||
|
||||
108
pkg/mqttspec/security_hooks.go
Normal file
108
pkg/mqttspec/security_hooks.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mqttspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LoadSecurityRules(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||
})
|
||||
|
||||
// Hook 3 (Optional): Audit logging
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for mqttspec handler")
|
||||
}
|
||||
|
||||
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
// GetQuery retrieves a stored query from hook metadata
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
if s.ctx.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return s.ctx.Metadata["query"]
|
||||
}
|
||||
|
||||
// SetQuery stores the query in hook metadata
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if s.ctx.Metadata == nil {
|
||||
s.ctx.Metadata = make(map[string]interface{})
|
||||
}
|
||||
s.ctx.Metadata["query"] = query
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package reflection
|
||||
|
||||
import "reflect"
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Len(v any) int {
|
||||
val := reflect.ValueOf(v)
|
||||
@@ -64,3 +67,41 @@ func GetPointerElement(v reflect.Type) reflect.Type {
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetJSONNameForField gets the JSON tag name for a struct field.
|
||||
// Returns the JSON field name from the json struct tag, or an empty string if not found.
|
||||
// Handles the "json" tag format: "name", "name,omitempty", etc.
|
||||
func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
||||
if modelType == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle pointer types
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType.Kind() != reflect.Struct {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the field
|
||||
field, found := modelType.FieldByName(fieldName)
|
||||
if !found {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Get the JSON tag
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Parse the tag (format: "name,omitempty" or just "name")
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" && parts[0] != "-" {
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -948,29 +948,35 @@ func MapToStruct(dataMap map[string]interface{}, target interface{}) error {
|
||||
// Build list of possible column names for this field
|
||||
var columnNames []string
|
||||
|
||||
// 1. Bun tag
|
||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Gorm tag
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. JSON tag
|
||||
// 1. JSON tag (primary - most common)
|
||||
jsonFound := false
|
||||
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
columnNames = append(columnNames, parts[0])
|
||||
jsonFound = true
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Field name variations
|
||||
// 2. Bun tag (fallback if no JSON tag)
|
||||
if !jsonFound {
|
||||
if bunTag := field.Tag.Get("bun"); bunTag != "" && bunTag != "-" {
|
||||
if colName := ExtractColumnFromBunTag(bunTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Gorm tag (fallback if no JSON tag)
|
||||
if !jsonFound {
|
||||
if gormTag := field.Tag.Get("gorm"); gormTag != "" && gormTag != "-" {
|
||||
if colName := ExtractColumnFromGormTag(gormTag); colName != "" {
|
||||
columnNames = append(columnNames, colName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Field name variations (last resort)
|
||||
columnNames = append(columnNames, field.Name)
|
||||
columnNames = append(columnNames, strings.ToLower(field.Name))
|
||||
// columnNames = append(columnNames, ToSnakeCase(field.Name))
|
||||
@@ -1096,6 +1102,12 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
}
|
||||
}
|
||||
|
||||
// If we can convert the type, do it
|
||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||
field.Set(valueReflect.Convert(field.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle struct types (like SqlTimeStamp, SqlDate, SqlTime which wrap SqlNull[time.Time])
|
||||
if field.Kind() == reflect.Struct {
|
||||
|
||||
@@ -1107,9 +1119,9 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
// Call the Scan method with the value
|
||||
results := scanMethod.Call([]reflect.Value{reflect.ValueOf(value)})
|
||||
if len(results) > 0 {
|
||||
// Check if there was an error
|
||||
if err, ok := results[0].Interface().(error); ok && err != nil {
|
||||
return err
|
||||
// The Scan method returns error - check if it's nil
|
||||
if !results[0].IsNil() {
|
||||
return results[0].Interface().(error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1164,12 +1176,6 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
||||
|
||||
}
|
||||
|
||||
// If we can convert the type, do it
|
||||
if valueReflect.Type().ConvertibleTo(field.Type()) {
|
||||
field.Set(valueReflect.Convert(field.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("cannot convert %v to %v", valueReflect.Type(), field.Type())
|
||||
}
|
||||
|
||||
@@ -1364,6 +1370,63 @@ func convertToFloat64(value interface{}) (float64, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetValidJSONFieldNames returns a map of valid JSON field names for a model
|
||||
// This can be used to validate input data against a model's structure
|
||||
// The map keys are the JSON field names (from json tags) that exist in the model
|
||||
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||
validFields := make(map[string]bool)
|
||||
|
||||
// Unwrap pointers to get to the base struct type
|
||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return validFields
|
||||
}
|
||||
|
||||
collectValidFieldNames(modelType, validFields)
|
||||
return validFields
|
||||
}
|
||||
|
||||
// collectValidFieldNames recursively collects valid JSON field names from a struct type
|
||||
func collectValidFieldNames(typ reflect.Type, validFields map[string]bool) {
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
|
||||
// Skip unexported fields
|
||||
if !field.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for embedded structs
|
||||
if field.Anonymous {
|
||||
fieldType := field.Type
|
||||
if fieldType.Kind() == reflect.Ptr {
|
||||
fieldType = fieldType.Elem()
|
||||
}
|
||||
if fieldType.Kind() == reflect.Struct {
|
||||
// Recursively add fields from embedded struct
|
||||
collectValidFieldNames(fieldType, validFields)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Get the JSON tag name for this field (same logic as MapToStruct)
|
||||
jsonTag := field.Tag.Get("json")
|
||||
if jsonTag != "" && jsonTag != "-" {
|
||||
// Extract the field name from the JSON tag (before any options like omitempty)
|
||||
parts := strings.Split(jsonTag, ",")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
validFields[parts[0]] = true
|
||||
}
|
||||
} else {
|
||||
// If no JSON tag, use the field name in lowercase as a fallback
|
||||
validFields[strings.ToLower(field.Name)] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRelationModelSingleLevel gets the model type for a single level field (non-recursive)
|
||||
// This is a helper function used by GetRelationModel to handle one level at a time
|
||||
func getRelationModelSingleLevel(model interface{}, fieldName string) interface{} {
|
||||
|
||||
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
120
pkg/reflection/model_utils_stdlib_sqltypes_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package reflection_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
func TestMapToStruct_StandardSqlNullTypes(t *testing.T) {
|
||||
// Test model with standard library sql.Null* types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||
Name sql.NullString `bun:"name" json:"name"`
|
||||
Score sql.NullFloat64 `bun:"score" json:"score"`
|
||||
Active sql.NullBool `bun:"active" json:"active"`
|
||||
UpdatedAt sql.NullTime `bun:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
dataMap := map[string]any{
|
||||
"id": int64(100),
|
||||
"age": int64(25),
|
||||
"name": "John Doe",
|
||||
"score": 95.5,
|
||||
"active": true,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify ID
|
||||
if result.ID != 100 {
|
||||
t.Errorf("ID = %v, want 100", result.ID)
|
||||
}
|
||||
|
||||
// Verify Age (sql.NullInt64)
|
||||
if !result.Age.Valid {
|
||||
t.Error("Age.Valid = false, want true")
|
||||
}
|
||||
if result.Age.Int64 != 25 {
|
||||
t.Errorf("Age.Int64 = %v, want 25", result.Age.Int64)
|
||||
}
|
||||
|
||||
// Verify Name (sql.NullString)
|
||||
if !result.Name.Valid {
|
||||
t.Error("Name.Valid = false, want true")
|
||||
}
|
||||
if result.Name.String != "John Doe" {
|
||||
t.Errorf("Name.String = %v, want 'John Doe'", result.Name.String)
|
||||
}
|
||||
|
||||
// Verify Score (sql.NullFloat64)
|
||||
if !result.Score.Valid {
|
||||
t.Error("Score.Valid = false, want true")
|
||||
}
|
||||
if result.Score.Float64 != 95.5 {
|
||||
t.Errorf("Score.Float64 = %v, want 95.5", result.Score.Float64)
|
||||
}
|
||||
|
||||
// Verify Active (sql.NullBool)
|
||||
if !result.Active.Valid {
|
||||
t.Error("Active.Valid = false, want true")
|
||||
}
|
||||
if !result.Active.Bool {
|
||||
t.Error("Active.Bool = false, want true")
|
||||
}
|
||||
|
||||
// Verify UpdatedAt (sql.NullTime)
|
||||
if !result.UpdatedAt.Valid {
|
||||
t.Error("UpdatedAt.Valid = false, want true")
|
||||
}
|
||||
if !result.UpdatedAt.Time.Equal(now) {
|
||||
t.Errorf("UpdatedAt.Time = %v, want %v", result.UpdatedAt.Time, now)
|
||||
}
|
||||
|
||||
t.Log("All standard library sql.Null* types handled correctly!")
|
||||
}
|
||||
|
||||
func TestMapToStruct_StandardSqlNullTypes_WithNil(t *testing.T) {
|
||||
// Test nil handling for standard library sql.Null* types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Age sql.NullInt64 `bun:"age" json:"age"`
|
||||
Name sql.NullString `bun:"name" json:"name"`
|
||||
}
|
||||
|
||||
dataMap := map[string]any{
|
||||
"id": int64(200),
|
||||
"age": int64(30),
|
||||
"name": nil, // Explicitly nil
|
||||
}
|
||||
|
||||
var result TestModel
|
||||
err := reflection.MapToStruct(dataMap, &result)
|
||||
if err != nil {
|
||||
t.Fatalf("MapToStruct() error = %v", err)
|
||||
}
|
||||
|
||||
// Age should be valid
|
||||
if !result.Age.Valid {
|
||||
t.Error("Age.Valid = false, want true")
|
||||
}
|
||||
if result.Age.Int64 != 30 {
|
||||
t.Errorf("Age.Int64 = %v, want 30", result.Age.Int64)
|
||||
}
|
||||
|
||||
// Name should be invalid (null)
|
||||
if result.Name.Valid {
|
||||
t.Error("Name.Valid = true, want false (null)")
|
||||
}
|
||||
|
||||
t.Log("Nil handling for sql.Null* types works correctly!")
|
||||
}
|
||||
364
pkg/reflection/spectypes_integration_test.go
Normal file
364
pkg/reflection/spectypes_integration_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package reflection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/spectypes"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// TestModel contains all spectypes custom types
|
||||
type TestModel struct {
|
||||
ID int64 `bun:"id,pk" json:"id"`
|
||||
Name spectypes.SqlString `bun:"name" json:"name"`
|
||||
Age spectypes.SqlInt64 `bun:"age" json:"age"`
|
||||
Score spectypes.SqlFloat64 `bun:"score" json:"score"`
|
||||
Active spectypes.SqlBool `bun:"active" json:"active"`
|
||||
UUID spectypes.SqlUUID `bun:"uuid" json:"uuid"`
|
||||
CreatedAt spectypes.SqlTimeStamp `bun:"created_at" json:"created_at"`
|
||||
BirthDate spectypes.SqlDate `bun:"birth_date" json:"birth_date"`
|
||||
StartTime spectypes.SqlTime `bun:"start_time" json:"start_time"`
|
||||
Metadata spectypes.SqlJSONB `bun:"metadata" json:"metadata"`
|
||||
Count16 spectypes.SqlInt16 `bun:"count16" json:"count16"`
|
||||
Count32 spectypes.SqlInt32 `bun:"count32" json:"count32"`
|
||||
}
|
||||
|
||||
// TestMapToStruct_AllSpectypes verifies that MapToStruct can convert all spectypes correctly
|
||||
func TestMapToStruct_AllSpectypes(t *testing.T) {
|
||||
testUUID := uuid.New()
|
||||
testTime := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
dataMap map[string]interface{}
|
||||
validator func(*testing.T, *TestModel)
|
||||
}{
|
||||
{
|
||||
name: "SqlString from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"name": "John Doe",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Name.Valid || m.Name.String() != "John Doe" {
|
||||
t.Errorf("expected name='John Doe', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt64 from int64",
|
||||
dataMap: map[string]interface{}{
|
||||
"age": int64(42),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Age.Valid || m.Age.Int64() != 42 {
|
||||
t.Errorf("expected age=42, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt64 from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"age": "99",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Age.Valid || m.Age.Int64() != 99 {
|
||||
t.Errorf("expected age=99, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlFloat64 from float64",
|
||||
dataMap: map[string]interface{}{
|
||||
"score": float64(98.5),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Score.Valid || m.Score.Float64() != 98.5 {
|
||||
t.Errorf("expected score=98.5, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlBool from bool",
|
||||
dataMap: map[string]interface{}{
|
||||
"active": true,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Active.Valid || !m.Active.Bool() {
|
||||
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlUUID from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"uuid": testUUID.String(),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.UUID.Valid || m.UUID.UUID() != testUUID {
|
||||
t.Errorf("expected uuid=%s, got valid=%v, value=%s", testUUID.String(), m.UUID.Valid, m.UUID.UUID().String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTimeStamp from time.Time",
|
||||
dataMap: map[string]interface{}{
|
||||
"created_at": testTime,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Errorf("expected created_at to be valid")
|
||||
}
|
||||
// Check if times are close enough (within a second)
|
||||
diff := m.CreatedAt.Time().Sub(testTime)
|
||||
if diff < -time.Second || diff > time.Second {
|
||||
t.Errorf("time difference too large: %v", diff)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTimeStamp from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"created_at": "2024-01-15T10:30:00",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Errorf("expected created_at to be valid")
|
||||
}
|
||||
expected := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||
if m.CreatedAt.Time().Year() != expected.Year() ||
|
||||
m.CreatedAt.Time().Month() != expected.Month() ||
|
||||
m.CreatedAt.Time().Day() != expected.Day() {
|
||||
t.Errorf("expected date 2024-01-15, got %v", m.CreatedAt.Time())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlDate from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"birth_date": "2000-05-20",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.BirthDate.Valid {
|
||||
t.Errorf("expected birth_date to be valid")
|
||||
}
|
||||
expected := "2000-05-20"
|
||||
if m.BirthDate.String() != expected {
|
||||
t.Errorf("expected date=%s, got %s", expected, m.BirthDate.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlTime from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"start_time": "14:30:00",
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.StartTime.Valid {
|
||||
t.Errorf("expected start_time to be valid")
|
||||
}
|
||||
if m.StartTime.String() != "14:30:00" {
|
||||
t.Errorf("expected time=14:30:00, got %s", m.StartTime.String())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from map",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": map[string]interface{}{
|
||||
"key1": "value1",
|
||||
"key2": 123,
|
||||
},
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
asMap, err := m.Metadata.AsMap()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||
}
|
||||
if asMap["key1"] != "value1" {
|
||||
t.Errorf("expected key1=value1, got %v", asMap["key1"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from string",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": `{"test":"data"}`,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
asMap, err := m.Metadata.AsMap()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to convert metadata to map: %v", err)
|
||||
}
|
||||
if asMap["test"] != "data" {
|
||||
t.Errorf("expected test=data, got %v", asMap["test"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlJSONB from []byte",
|
||||
dataMap: map[string]interface{}{
|
||||
"metadata": []byte(`{"byte":"array"}`),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Errorf("expected metadata to have data")
|
||||
}
|
||||
if string(m.Metadata) != `{"byte":"array"}` {
|
||||
t.Errorf("expected {\"byte\":\"array\"}, got %s", string(m.Metadata))
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt16 from int16",
|
||||
dataMap: map[string]interface{}{
|
||||
"count16": int16(100),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Count16.Valid || m.Count16.Int64() != 100 {
|
||||
t.Errorf("expected count16=100, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SqlInt32 from int32",
|
||||
dataMap: map[string]interface{}{
|
||||
"count32": int32(5000),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if !m.Count32.Valid || m.Count32.Int64() != 5000 {
|
||||
t.Errorf("expected count32=5000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nil values create invalid nulls",
|
||||
dataMap: map[string]interface{}{
|
||||
"name": nil,
|
||||
"age": nil,
|
||||
"active": nil,
|
||||
"created_at": nil,
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if m.Name.Valid {
|
||||
t.Error("expected name to be invalid for nil value")
|
||||
}
|
||||
if m.Age.Valid {
|
||||
t.Error("expected age to be invalid for nil value")
|
||||
}
|
||||
if m.Active.Valid {
|
||||
t.Error("expected active to be invalid for nil value")
|
||||
}
|
||||
if m.CreatedAt.Valid {
|
||||
t.Error("expected created_at to be invalid for nil value")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all types together",
|
||||
dataMap: map[string]interface{}{
|
||||
"id": int64(1),
|
||||
"name": "Test User",
|
||||
"age": int64(30),
|
||||
"score": float64(95.7),
|
||||
"active": true,
|
||||
"uuid": testUUID.String(),
|
||||
"created_at": "2024-01-15T10:30:00",
|
||||
"birth_date": "1994-06-15",
|
||||
"start_time": "09:00:00",
|
||||
"metadata": map[string]interface{}{"role": "admin"},
|
||||
"count16": int16(50),
|
||||
"count32": int32(1000),
|
||||
},
|
||||
validator: func(t *testing.T, m *TestModel) {
|
||||
if m.ID != 1 {
|
||||
t.Errorf("expected id=1, got %d", m.ID)
|
||||
}
|
||||
if !m.Name.Valid || m.Name.String() != "Test User" {
|
||||
t.Errorf("expected name='Test User', got valid=%v, value=%s", m.Name.Valid, m.Name.String())
|
||||
}
|
||||
if !m.Age.Valid || m.Age.Int64() != 30 {
|
||||
t.Errorf("expected age=30, got valid=%v, value=%d", m.Age.Valid, m.Age.Int64())
|
||||
}
|
||||
if !m.Score.Valid || m.Score.Float64() != 95.7 {
|
||||
t.Errorf("expected score=95.7, got valid=%v, value=%f", m.Score.Valid, m.Score.Float64())
|
||||
}
|
||||
if !m.Active.Valid || !m.Active.Bool() {
|
||||
t.Errorf("expected active=true, got valid=%v, value=%v", m.Active.Valid, m.Active.Bool())
|
||||
}
|
||||
if !m.UUID.Valid {
|
||||
t.Error("expected uuid to be valid")
|
||||
}
|
||||
if !m.CreatedAt.Valid {
|
||||
t.Error("expected created_at to be valid")
|
||||
}
|
||||
if !m.BirthDate.Valid || m.BirthDate.String() != "1994-06-15" {
|
||||
t.Errorf("expected birth_date=1994-06-15, got valid=%v, value=%s", m.BirthDate.Valid, m.BirthDate.String())
|
||||
}
|
||||
if !m.StartTime.Valid || m.StartTime.String() != "09:00:00" {
|
||||
t.Errorf("expected start_time=09:00:00, got valid=%v, value=%s", m.StartTime.Valid, m.StartTime.String())
|
||||
}
|
||||
if len(m.Metadata) == 0 {
|
||||
t.Error("expected metadata to have data")
|
||||
}
|
||||
if !m.Count16.Valid || m.Count16.Int64() != 50 {
|
||||
t.Errorf("expected count16=50, got valid=%v, value=%d", m.Count16.Valid, m.Count16.Int64())
|
||||
}
|
||||
if !m.Count32.Valid || m.Count32.Int64() != 1000 {
|
||||
t.Errorf("expected count32=1000, got valid=%v, value=%d", m.Count32.Valid, m.Count32.Int64())
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := &TestModel{}
|
||||
if err := MapToStruct(tt.dataMap, model); err != nil {
|
||||
t.Fatalf("MapToStruct failed: %v", err)
|
||||
}
|
||||
tt.validator(t, model)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapToStruct_PartialUpdate tests that partial updates preserve unset fields
|
||||
func TestMapToStruct_PartialUpdate(t *testing.T) {
|
||||
// Create initial model with some values
|
||||
initial := &TestModel{
|
||||
ID: 1,
|
||||
Name: spectypes.NewSqlString("Original Name"),
|
||||
Age: spectypes.NewSqlInt64(25),
|
||||
}
|
||||
|
||||
// Update only the age field
|
||||
partialUpdate := map[string]interface{}{
|
||||
"age": int64(30),
|
||||
}
|
||||
|
||||
// Apply partial update
|
||||
if err := MapToStruct(partialUpdate, initial); err != nil {
|
||||
t.Fatalf("MapToStruct failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify age was updated
|
||||
if !initial.Age.Valid || initial.Age.Int64() != 30 {
|
||||
t.Errorf("expected age=30, got valid=%v, value=%d", initial.Age.Valid, initial.Age.Int64())
|
||||
}
|
||||
|
||||
// Verify name was preserved (not overwritten with zero value)
|
||||
if !initial.Name.Valid || initial.Name.String() != "Original Name" {
|
||||
t.Errorf("expected name='Original Name' to be preserved, got valid=%v, value=%s", initial.Name.Valid, initial.Name.String())
|
||||
}
|
||||
|
||||
// Verify ID was preserved
|
||||
if initial.ID != 1 {
|
||||
t.Errorf("expected id=1 to be preserved, got %d", initial.ID)
|
||||
}
|
||||
}
|
||||
670
pkg/resolvemcp/README.md
Normal file
670
pkg/resolvemcp/README.md
Normal file
@@ -0,0 +1,670 @@
|
||||
# resolvemcp
|
||||
|
||||
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```go
|
||||
import (
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// 1. Create a handler
|
||||
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
|
||||
BaseURL: "http://localhost:8080",
|
||||
})
|
||||
|
||||
// 2. Register models
|
||||
handler.RegisterModel("public", "users", &User{})
|
||||
handler.RegisterModel("public", "orders", &Order{})
|
||||
|
||||
// 3. Mount routes
|
||||
r := mux.NewRouter()
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Config
|
||||
|
||||
```go
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
// If empty, it is detected from each incoming request using the Host header and
|
||||
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
|
||||
// Required.
|
||||
BasePath string
|
||||
}
|
||||
```
|
||||
|
||||
## Handler Creation
|
||||
|
||||
| Function | Description |
|
||||
|---|---|
|
||||
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
|
||||
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
|
||||
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
|
||||
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
|
||||
|
||||
---
|
||||
|
||||
## Registering Models
|
||||
|
||||
```go
|
||||
handler.RegisterModel(schema, entity string, model interface{}) error
|
||||
```
|
||||
|
||||
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
|
||||
- `entity` — table/entity name (e.g. `"users"`).
|
||||
- `model` — a pointer to a struct (e.g. `&User{}`).
|
||||
|
||||
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
|
||||
|
||||
---
|
||||
|
||||
## HTTP Transports
|
||||
|
||||
`Config.BasePath` is required and used for all route registration.
|
||||
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||
|
||||
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
|
||||
|
||||
---
|
||||
|
||||
### SSE Transport
|
||||
|
||||
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxRoutes(r, handler)
|
||||
```
|
||||
|
||||
| Route | Method | Description |
|
||||
|---|---|---|
|
||||
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||
```
|
||||
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
sse := handler.SSEServer()
|
||||
|
||||
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
|
||||
http.Handle("/mcp/", sse) // net/http
|
||||
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Streamable HTTP Transport
|
||||
|
||||
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
|
||||
|
||||
#### Gorilla Mux
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
|
||||
```
|
||||
|
||||
Mounts the handler at `{BasePath}` (all methods).
|
||||
|
||||
#### bunrouter
|
||||
|
||||
```go
|
||||
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
|
||||
```
|
||||
|
||||
Registers GET, POST, DELETE on `{BasePath}`.
|
||||
|
||||
#### Gin / net/http / Echo
|
||||
|
||||
```go
|
||||
h := handler.StreamableHTTPServer()
|
||||
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
|
||||
engine.Any("/mcp", gin.WrapH(h)) // Gin
|
||||
http.Handle("/mcp", h) // net/http
|
||||
e.Any("/mcp", echo.WrapHandler(h)) // Echo
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## OAuth2 Authentication
|
||||
|
||||
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
|
||||
|
||||
It can operate as:
|
||||
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
|
||||
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
|
||||
- **Both simultaneously**
|
||||
|
||||
### Standard endpoints served
|
||||
|
||||
| Path | Spec | Purpose |
|
||||
|---|---|---|
|
||||
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
|
||||
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
|
||||
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
|
||||
| `POST /oauth/authorize` | — | Login form submission |
|
||||
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
|
||||
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
|
||||
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
|
||||
|
||||
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
|
||||
|
||||
---
|
||||
|
||||
### Mode 1 — Direct login (server as identity provider)
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
db, _ := sql.Open("postgres", dsn)
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
|
||||
BaseURL: "https://api.example.com",
|
||||
BasePath: "/mcp",
|
||||
})
|
||||
|
||||
// Enable the OAuth2 server — auth enables the login form
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
security.RegisterSecurityHooks(handler, securityList)
|
||||
|
||||
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
```
|
||||
|
||||
MCP client flow:
|
||||
1. Discovers server at `/.well-known/oauth-authorization-server`
|
||||
2. Registers itself at `/oauth/register`
|
||||
3. Redirects user to `/oauth/authorize` → login form appears
|
||||
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
|
||||
5. Uses token on all MCP tool calls
|
||||
|
||||
---
|
||||
|
||||
### Mode 2 — External provider (Google, GitHub, etc.)
|
||||
|
||||
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
RedirectURL: "https://api.example.com/oauth/provider/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
|
||||
// nil = no password login; Google handles auth
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, nil)
|
||||
handler.RegisterOAuth2Provider(auth, "google")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Mode 3 — Both (login form + external providers)
|
||||
|
||||
```go
|
||||
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
LoginTitle: "My App Login",
|
||||
}, auth) // auth enables the username/password form
|
||||
|
||||
handler.RegisterOAuth2Provider(googleAuth, "google")
|
||||
handler.RegisterOAuth2Provider(githubAuth, "github")
|
||||
```
|
||||
|
||||
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
|
||||
|
||||
---
|
||||
|
||||
### Using `security.OAuthServer` standalone
|
||||
|
||||
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
|
||||
|
||||
```go
|
||||
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
|
||||
Issuer: "https://api.example.com",
|
||||
}, auth)
|
||||
oauthSrv.RegisterExternalProvider(googleAuth, "google")
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
|
||||
mux.Handle("/mcp/", myMCPHandler)
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Cookie-based flow (legacy)
|
||||
|
||||
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
|
||||
|
||||
```go
|
||||
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
ProviderName: "google",
|
||||
LoginPath: "/auth/google/login",
|
||||
CallbackPath: "/auth/google/callback",
|
||||
AfterLoginRedirect: "/",
|
||||
})
|
||||
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||
|
||||
### Wiring security hooks
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
securityList := security.NewSecurityList(mySecurityProvider)
|
||||
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||
```
|
||||
|
||||
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||
|
||||
| Hook | Effect |
|
||||
|---|---|
|
||||
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||
|
||||
### Per-entity operation rules
|
||||
|
||||
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
|
||||
// Read-only entity
|
||||
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: false,
|
||||
CanUpdate: false,
|
||||
CanDelete: false,
|
||||
})
|
||||
|
||||
// Public read, authenticated write
|
||||
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||
CanPublicRead: true,
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
To update rules for an already-registered model:
|
||||
|
||||
```go
|
||||
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||
CanRead: true,
|
||||
CanCreate: true,
|
||||
CanUpdate: true,
|
||||
CanDelete: false,
|
||||
})
|
||||
```
|
||||
|
||||
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||
|
||||
### ModelRules fields
|
||||
|
||||
| Field | Default | Description |
|
||||
|---|---|---|
|
||||
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||
| `CanRead` | `true` | Allow authenticated reads |
|
||||
| `CanCreate` | `true` | Allow authenticated creates |
|
||||
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||
|
||||
---
|
||||
|
||||
## MCP Tools
|
||||
|
||||
### Tool Naming
|
||||
|
||||
```
|
||||
{operation}_{schema}_{entity} // e.g. read_public_users
|
||||
{operation}_{entity} // e.g. read_users (when schema is empty)
|
||||
```
|
||||
|
||||
Operations: `read`, `create`, `update`, `delete`.
|
||||
|
||||
### Read Tool — `read_{schema}_{entity}`
|
||||
|
||||
Fetch one or many records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key value. Omit to return multiple records. |
|
||||
| `limit` | number | Max records per page (recommended: 10–100). |
|
||||
| `offset` | number | Records to skip (offset-based pagination). |
|
||||
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
|
||||
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
|
||||
| `columns` | array | Column names to include. Omit for all columns. |
|
||||
| `omit_columns` | array | Column names to exclude. |
|
||||
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
|
||||
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
|
||||
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
"metadata": {
|
||||
"total": 100,
|
||||
"filtered": 100,
|
||||
"count": 10,
|
||||
"limit": 10,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Create Tool — `create_{schema}_{entity}`
|
||||
|
||||
Insert one or more records.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `data` | object \| array | Single object or array of objects to insert. |
|
||||
|
||||
Array input runs inside a single transaction — all succeed or all fail.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ... } }
|
||||
```
|
||||
|
||||
### Update Tool — `update_{schema}_{entity}`
|
||||
|
||||
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string | Primary key of the record. Can also be included inside `data`. |
|
||||
| `data` | object (required) | Fields to update. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...merged record... } }
|
||||
```
|
||||
|
||||
### Delete Tool — `delete_{schema}_{entity}`
|
||||
|
||||
Delete a record by primary key. **Irreversible.**
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `id` | string (required) | Primary key of the record to delete. |
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "data": { ...deleted record... } }
|
||||
```
|
||||
|
||||
### Annotation Tool — `resolvespec_annotate`
|
||||
|
||||
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||
|
||||
| Argument | Type | Description |
|
||||
|---|---|---|
|
||||
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||
|
||||
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||
```
|
||||
|
||||
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||
```json
|
||||
{ "tool_name": "read_public_users" }
|
||||
```
|
||||
**Response:**
|
||||
```json
|
||||
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Resource — `{schema}.{entity}`
|
||||
|
||||
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||
|
||||
---
|
||||
|
||||
## Filtering
|
||||
|
||||
Pass an array of filter objects to the `filters` argument:
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "status", "operator": "=", "value": "active" },
|
||||
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
|
||||
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
|
||||
]
|
||||
```
|
||||
|
||||
### Supported Operators
|
||||
|
||||
| Operator | Aliases | Description |
|
||||
|---|---|---|
|
||||
| `=` | `eq` | Equal |
|
||||
| `!=` | `neq`, `<>` | Not equal |
|
||||
| `>` | `gt` | Greater than |
|
||||
| `>=` | `gte` | Greater than or equal |
|
||||
| `<` | `lt` | Less than |
|
||||
| `<=` | `lte` | Less than or equal |
|
||||
| `like` | | SQL LIKE (case-sensitive) |
|
||||
| `ilike` | | SQL ILIKE (case-insensitive) |
|
||||
| `in` | | Value in list |
|
||||
| `is_null` | | Column IS NULL |
|
||||
| `is_not_null` | | Column IS NOT NULL |
|
||||
|
||||
### Logic Operators
|
||||
|
||||
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
|
||||
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
|
||||
|
||||
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
|
||||
|
||||
---
|
||||
|
||||
## Sorting
|
||||
|
||||
```json
|
||||
[
|
||||
{ "column": "created_at", "direction": "desc" },
|
||||
{ "column": "name", "direction": "asc" }
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pagination
|
||||
|
||||
### Offset-Based
|
||||
|
||||
```json
|
||||
{ "limit": 20, "offset": 40 }
|
||||
```
|
||||
|
||||
### Cursor-Based
|
||||
|
||||
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
|
||||
|
||||
```json
|
||||
// Next page: pass the PK of the last record on the current page
|
||||
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
|
||||
// Previous page: pass the PK of the first record on the current page
|
||||
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Preloading Relations
|
||||
|
||||
```json
|
||||
[
|
||||
{ "relation": "Profile" },
|
||||
{ "relation": "Orders" }
|
||||
]
|
||||
```
|
||||
|
||||
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
|
||||
|
||||
---
|
||||
|
||||
## Hook System
|
||||
|
||||
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
|
||||
|
||||
### Hook Types
|
||||
|
||||
| Constant | Fires |
|
||||
|---|---|
|
||||
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
|
||||
| `BeforeRead` / `AfterRead` | Around read queries |
|
||||
| `BeforeCreate` / `AfterCreate` | Around insert |
|
||||
| `BeforeUpdate` / `AfterUpdate` | Around update |
|
||||
| `BeforeDelete` / `AfterDelete` | Around delete |
|
||||
|
||||
### Registering Hooks
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
|
||||
// Inject a timestamp before insert
|
||||
if data, ok := ctx.Data.(map[string]interface{}); ok {
|
||||
data["created_at"] = time.Now()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register the same hook for multiple events
|
||||
handler.Hooks().RegisterMultiple(
|
||||
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
|
||||
auditHook,
|
||||
)
|
||||
```
|
||||
|
||||
### HookContext Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|---|---|---|
|
||||
| `Context` | `context.Context` | Request context |
|
||||
| `Handler` | `*Handler` | The resolvemcp handler |
|
||||
| `Schema` | `string` | Database schema name |
|
||||
| `Entity` | `string` | Entity/table name |
|
||||
| `Model` | `interface{}` | Registered model instance |
|
||||
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
|
||||
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
|
||||
| `ID` | `string` | Primary key from request (read/update/delete) |
|
||||
| `Data` | `interface{}` | Input data (create/update — modifiable) |
|
||||
| `Result` | `interface{}` | Output data (set by After hooks) |
|
||||
| `Error` | `error` | Operation error, if any |
|
||||
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
|
||||
| `Tx` | `common.Database` | Database/transaction handle |
|
||||
| `Abort` | `bool` | Set to `true` to abort the operation |
|
||||
| `AbortMessage` | `string` | Error message returned when aborting |
|
||||
| `AbortCode` | `int` | Optional status code for the abort |
|
||||
|
||||
### Aborting an Operation
|
||||
|
||||
```go
|
||||
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
|
||||
ctx.Abort = true
|
||||
ctx.AbortMessage = "deletion is disabled"
|
||||
return nil
|
||||
})
|
||||
```
|
||||
|
||||
### Managing Hooks
|
||||
|
||||
```go
|
||||
registry := handler.Hooks()
|
||||
registry.HasHooks(resolvemcp.BeforeCreate) // bool
|
||||
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
|
||||
registry.ClearAll() // remove all hooks
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Context Helpers
|
||||
|
||||
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
|
||||
|
||||
```go
|
||||
schema := resolvemcp.GetSchema(ctx)
|
||||
entity := resolvemcp.GetEntity(ctx)
|
||||
tableName := resolvemcp.GetTableName(ctx)
|
||||
model := resolvemcp.GetModel(ctx)
|
||||
modelPtr := resolvemcp.GetModelPtr(ctx)
|
||||
```
|
||||
|
||||
You can also set values manually (e.g. in middleware):
|
||||
|
||||
```go
|
||||
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding Custom MCP Tools
|
||||
|
||||
Access the underlying `*server.MCPServer` to register additional tools:
|
||||
|
||||
```go
|
||||
mcpServer := handler.MCPServer()
|
||||
mcpServer.AddTool(myTool, myHandler)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Table Name Resolution
|
||||
|
||||
The handler resolves table names in priority order:
|
||||
|
||||
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
|
||||
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
|
||||
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)
|
||||
107
pkg/resolvemcp/annotation.go
Normal file
107
pkg/resolvemcp/annotation.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
)
|
||||
|
||||
const annotationToolName = "resolvespec_annotate"
|
||||
|
||||
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||
// The tool lets models/entities store and retrieve freeform annotation records
|
||||
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||
func registerAnnotationTool(h *Handler) {
|
||||
tool := mcp.NewTool(annotationToolName,
|
||||
mcp.WithDescription(
|
||||
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||
"To get annotations: provide only 'tool_name'. "+
|
||||
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||
),
|
||||
mcp.WithString("tool_name",
|
||||
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||
mcp.Required(),
|
||||
),
|
||||
mcp.WithObject("annotations",
|
||||
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
|
||||
toolName, ok := args["tool_name"].(string)
|
||||
if !ok || toolName == "" {
|
||||
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||
}
|
||||
|
||||
annotations, hasAnnotations := args["annotations"]
|
||||
|
||||
if hasAnnotations && annotations != nil {
|
||||
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||
}
|
||||
return executeGetAnnotation(ctx, h, toolName)
|
||||
})
|
||||
}
|
||||
|
||||
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||
jsonBytes, err := json.Marshal(annotations)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||
}
|
||||
|
||||
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "set",
|
||||
})
|
||||
}
|
||||
|
||||
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||
var rows []map[string]interface{}
|
||||
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||
}
|
||||
|
||||
var annotations interface{}
|
||||
if len(rows) > 0 {
|
||||
// The procedure returns a single value; extract the first column of the first row.
|
||||
for _, v := range rows[0] {
|
||||
annotations = v
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||
switch v := annotations.(type) {
|
||||
case []byte:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal(v, &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
case string:
|
||||
var decoded interface{}
|
||||
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||
annotations = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"tool_name": toolName,
|
||||
"action": "get",
|
||||
"annotations": annotations,
|
||||
})
|
||||
}
|
||||
71
pkg/resolvemcp/context.go
Normal file
71
pkg/resolvemcp/context.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package resolvemcp
|
||||
|
||||
import "context"
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeySchema contextKey = "schema"
|
||||
contextKeyEntity contextKey = "entity"
|
||||
contextKeyTableName contextKey = "tableName"
|
||||
contextKeyModel contextKey = "model"
|
||||
contextKeyModelPtr contextKey = "modelPtr"
|
||||
)
|
||||
|
||||
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||
return context.WithValue(ctx, contextKeySchema, schema)
|
||||
}
|
||||
|
||||
func GetSchema(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeySchema); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||
}
|
||||
|
||||
func GetEntity(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||
}
|
||||
|
||||
func GetTableName(ctx context.Context) string {
|
||||
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||
return v.(string)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModel, model)
|
||||
}
|
||||
|
||||
func GetModel(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModel)
|
||||
}
|
||||
|
||||
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||
}
|
||||
|
||||
func GetModelPtr(ctx context.Context) interface{} {
|
||||
return ctx.Value(contextKeyModelPtr)
|
||||
}
|
||||
|
||||
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||
ctx = WithSchema(ctx, schema)
|
||||
ctx = WithEntity(ctx, entity)
|
||||
ctx = WithTableName(ctx, tableName)
|
||||
ctx = WithModel(ctx, model)
|
||||
ctx = WithModelPtr(ctx, modelPtr)
|
||||
return ctx
|
||||
}
|
||||
161
pkg/resolvemcp/cursor.go
Normal file
161
pkg/resolvemcp/cursor.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package resolvemcp
|
||||
|
||||
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
type cursorDirection int
|
||||
|
||||
const (
|
||||
cursorForward cursorDirection = 1
|
||||
cursorBackward cursorDirection = -1
|
||||
)
|
||||
|
||||
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||
func getCursorFilter(
|
||||
tableName string,
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
|
||||
cursorID, direction := getActiveCursor(options)
|
||||
if cursorID == "" {
|
||||
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||
}
|
||||
|
||||
sortItems := options.Sort
|
||||
if len(sortItems) == 0 {
|
||||
return "", fmt.Errorf("no sort columns defined")
|
||||
}
|
||||
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
for _, s := range sortItems {
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
|
||||
desc := strings.EqualFold(s.Direction, "desc")
|
||||
if reverse {
|
||||
desc = !desc
|
||||
}
|
||||
|
||||
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||
if err != nil {
|
||||
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
op := "<"
|
||||
if desc {
|
||||
op = ">"
|
||||
}
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||
}
|
||||
|
||||
if len(whereClauses) == 0 {
|
||||
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||
}
|
||||
|
||||
orSQL := buildCursorPriorityChain(whereClauses)
|
||||
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||
if options.CursorForward != "" {
|
||||
return options.CursorForward, cursorForward
|
||||
}
|
||||
if options.CursorBackward != "" {
|
||||
return options.CursorBackward, cursorBackward
|
||||
}
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
func buildCursorPriorityChain(clauses []string) string {
|
||||
var or []string
|
||||
for i := 0; i < len(clauses); i++ {
|
||||
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||
or = append(or, "("+and+")")
|
||||
}
|
||||
return strings.Join(or, "\n OR ")
|
||||
}
|
||||
760
pkg/resolvemcp/handler.go
Normal file
760
pkg/resolvemcp/handler.go
Normal file
@@ -0,0 +1,760 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// Handler exposes registered database models as MCP tools and resources.
|
||||
type Handler struct {
|
||||
db common.Database
|
||||
registry common.ModelRegistry
|
||||
hooks *HookRegistry
|
||||
mcpServer *server.MCPServer
|
||||
config Config
|
||||
name string
|
||||
version string
|
||||
oauth2Regs []oauth2Registration
|
||||
oauthSrv *security.OAuthServer
|
||||
}
|
||||
|
||||
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||
h := &Handler{
|
||||
db: db,
|
||||
registry: registry,
|
||||
hooks: NewHookRegistry(),
|
||||
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||
config: cfg,
|
||||
name: "resolvemcp",
|
||||
version: "1.0.0",
|
||||
}
|
||||
registerAnnotationTool(h)
|
||||
return h
|
||||
}
|
||||
|
||||
// Hooks returns the hook registry.
|
||||
func (h *Handler) Hooks() *HookRegistry {
|
||||
return h.hooks
|
||||
}
|
||||
|
||||
// GetDatabase returns the underlying database.
|
||||
func (h *Handler) GetDatabase() common.Database {
|
||||
return h.db
|
||||
}
|
||||
|
||||
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||
func (h *Handler) MCPServer() *server.MCPServer {
|
||||
return h.mcpServer
|
||||
}
|
||||
|
||||
// SSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
|
||||
// detected automatically from each incoming request.
|
||||
func (h *Handler) SSEServer() http.Handler {
|
||||
if h.config.BaseURL != "" {
|
||||
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
|
||||
}
|
||||
return &dynamicSSEHandler{h: h}
|
||||
}
|
||||
|
||||
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
|
||||
// client-server communication (POST for requests, GET for server-initiated messages).
|
||||
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
|
||||
func (h *Handler) StreamableHTTPServer() http.Handler {
|
||||
return server.NewStreamableHTTPServer(h.mcpServer)
|
||||
}
|
||||
|
||||
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||
return server.NewSSEServer(
|
||||
h.mcpServer,
|
||||
server.WithBaseURL(baseURL),
|
||||
server.WithStaticBasePath(basePath),
|
||||
)
|
||||
}
|
||||
|
||||
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
|
||||
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
|
||||
type dynamicSSEHandler struct {
|
||||
h *Handler
|
||||
mu sync.Mutex
|
||||
pool map[string]*server.SSEServer
|
||||
}
|
||||
|
||||
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL := requestBaseURL(r)
|
||||
|
||||
d.mu.Lock()
|
||||
if d.pool == nil {
|
||||
d.pool = make(map[string]*server.SSEServer)
|
||||
}
|
||||
s, ok := d.pool[baseURL]
|
||||
if !ok {
|
||||
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
|
||||
d.pool[baseURL] = s
|
||||
}
|
||||
d.mu.Unlock()
|
||||
|
||||
s.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// requestBaseURL builds the base URL from an incoming request.
|
||||
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
|
||||
func requestBaseURL(r *http.Request) string {
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
}
|
||||
return scheme + "://" + r.Host
|
||||
}
|
||||
|
||||
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
fullName := buildModelName(schema, entity)
|
||||
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||
return err
|
||||
}
|
||||
registerModelTools(h, schema, entity, model)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetModelRules updates the operation rules for an already-registered model.
|
||||
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||
if !ok {
|
||||
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||
}
|
||||
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||
}
|
||||
|
||||
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||
func buildModelName(schema, entity string) string {
|
||||
if schema == "" {
|
||||
return entity
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schema, entity)
|
||||
}
|
||||
|
||||
// getTableName returns the fully qualified table name for a model.
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := tableProvider.TableName()
|
||||
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||
return tableName[:idx], tableName[idx+1:]
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), tableName
|
||||
}
|
||||
return defaultSchema, tableName
|
||||
}
|
||||
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||
return schemaProvider.SchemaName(), entity
|
||||
}
|
||||
return defaultSchema, entity
|
||||
}
|
||||
|
||||
// recoverPanic catches a panic from the current goroutine and returns it as an error.
|
||||
// Usage: defer recoverPanic(&returnedErr)
|
||||
func recoverPanic(err *error) {
|
||||
if r := recover(); r != nil {
|
||||
msg := fmt.Sprintf("%v", r)
|
||||
logger.Error("[resolvemcp] panic recovered: %s", msg)
|
||||
*err = fmt.Errorf("internal error: %s", msg)
|
||||
}
|
||||
}
|
||||
|
||||
// executeRead reads records from the database and returns raw data + metadata.
|
||||
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = unwrapped.Model
|
||||
modelType := unwrapped.ModelType
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||
|
||||
validator := common.NewColumnValidator(model)
|
||||
options = validator.FilterRequestOptions(options)
|
||||
|
||||
// BeforeHandle hook
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "read",
|
||||
Options: options,
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||
modelPtr := reflect.New(sliceType).Interface()
|
||||
|
||||
query := h.db.NewSelect().Model(modelPtr)
|
||||
|
||||
tempInstance := reflect.New(modelType).Interface()
|
||||
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||
query = query.Table(tableName)
|
||||
}
|
||||
|
||||
// Column selection
|
||||
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||
options.Columns = reflection.GetSQLModelColumns(model)
|
||||
}
|
||||
for _, col := range options.Columns {
|
||||
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||
}
|
||||
for _, cu := range options.ComputedColumns {
|
||||
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||
}
|
||||
|
||||
// Filters
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Sorting
|
||||
for _, sort := range options.Sort {
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||
}
|
||||
|
||||
// Cursor pagination
|
||||
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||
}
|
||||
|
||||
if cursorFilter != "" {
|
||||
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||
if sanitized != "" {
|
||||
query = query.Where(sanitized)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Count — must happen before preloads are applied; Bun panics when counting with relations.
|
||||
total, err := query.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||
}
|
||||
|
||||
// Pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
|
||||
// Preloads — applied after count to avoid Bun panic when counting with relations.
|
||||
if len(options.Preload) > 0 {
|
||||
var preloadErr error
|
||||
query, preloadErr = h.applyPreloads(model, query, options.Preload)
|
||||
if preloadErr != nil {
|
||||
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeRead hook
|
||||
hookCtx.Query = query
|
||||
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var data interface{}
|
||||
if id != "" {
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil, fmt.Errorf("record not found")
|
||||
}
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = singleResult
|
||||
} else {
|
||||
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||
}
|
||||
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||
}
|
||||
|
||||
limit := 0
|
||||
offset := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Count is the number of records in this page, not the total.
|
||||
var pageCount int64
|
||||
if id != "" {
|
||||
pageCount = 1
|
||||
} else {
|
||||
pageCount = int64(reflect.ValueOf(data).Len())
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: pageCount,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// AfterRead hook
|
||||
hookCtx.Result = data
|
||||
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return data, metadata, nil
|
||||
}
|
||||
|
||||
// executeCreate inserts one or more records.
|
||||
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "create",
|
||||
Data: data,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use potentially modified data
|
||||
data = hookCtx.Data
|
||||
|
||||
switch v := data.(type) {
|
||||
case map[string]interface{}:
|
||||
query := h.db.NewInsert().Table(tableName)
|
||||
for key, value := range v {
|
||||
query = query.Value(key, value)
|
||||
}
|
||||
if _, err := query.Exec(ctx); err != nil {
|
||||
return nil, fmt.Errorf("create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = v
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return v, nil
|
||||
|
||||
case []interface{}:
|
||||
results := make([]interface{}, 0, len(v))
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range v {
|
||||
itemMap, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("each item must be an object")
|
||||
}
|
||||
q := tx.NewInsert().Table(tableName)
|
||||
for key, value := range itemMap {
|
||||
q = q.Value(key, value)
|
||||
}
|
||||
if _, err := q.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
results = append(results, item)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch create error: %w", err)
|
||||
}
|
||||
hookCtx.Result = results
|
||||
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||
}
|
||||
return results, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||
}
|
||||
}
|
||||
|
||||
// executeUpdate updates a record by ID.
|
||||
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
updates, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("data must be an object")
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
if idVal, exists := updates["id"]; exists {
|
||||
id = fmt.Sprintf("%v", idVal)
|
||||
}
|
||||
}
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("update requires an ID")
|
||||
}
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
var updateResult interface{}
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Read existing record
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
existingRecord := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Convert to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "update",
|
||||
ID: id,
|
||||
Data: updates,
|
||||
Tx: tx,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge non-nil, non-empty values
|
||||
for key, newValue := range updates {
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
res, err := q.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
updateResult = existingMap
|
||||
hookCtx.Result = updateResult
|
||||
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return updateResult, nil
|
||||
}
|
||||
|
||||
// executeDelete deletes a record by ID.
|
||||
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
|
||||
defer recoverPanic(&retErr)
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("delete requires an ID")
|
||||
}
|
||||
|
||||
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
result, err := common.ValidateAndUnwrapModel(model)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid model: %w", err)
|
||||
}
|
||||
|
||||
model = result.Model
|
||||
tableName := h.getTableName(schema, entity, model)
|
||||
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Operation: "delete",
|
||||
ID: id,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
|
||||
var recordToDelete interface{}
|
||||
|
||||
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
record := reflect.New(modelType).Interface()
|
||||
selectQuery := tx.NewSelect().Model(record).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found")
|
||||
}
|
||||
return fmt.Errorf("error fetching record: %w", err)
|
||||
}
|
||||
|
||||
res, err := tx.NewDelete().Table(tableName).
|
||||
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||
Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete error: %w", err)
|
||||
}
|
||||
if res.RowsAffected() == 0 {
|
||||
return fmt.Errorf("record not found or already deleted")
|
||||
}
|
||||
|
||||
recordToDelete = record
|
||||
hookCtx.Tx = tx
|
||||
hookCtx.Result = record
|
||||
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||
return recordToDelete, nil
|
||||
}
|
||||
|
||||
// applyFilters applies all filters with OR grouping logic.
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (string, []interface{}) {
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||
return condition, args
|
||||
case "is_null":
|
||||
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||
case "is_not_null":
|
||||
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||
for _, preload := range preloads {
|
||||
if preload.Relation == "" {
|
||||
continue
|
||||
}
|
||||
query = query.PreloadRelation(preload.Relation)
|
||||
}
|
||||
return query, nil
|
||||
}
|
||||
113
pkg/resolvemcp/hooks.go
Normal file
113
pkg/resolvemcp/hooks.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
)
|
||||
|
||||
// HookType defines the type of hook to execute
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
|
||||
BeforeCreate HookType = "before_create"
|
||||
AfterCreate HookType = "after_create"
|
||||
|
||||
BeforeUpdate HookType = "before_update"
|
||||
AfterUpdate HookType = "after_update"
|
||||
|
||||
BeforeDelete HookType = "before_delete"
|
||||
AfterDelete HookType = "after_delete"
|
||||
)
|
||||
|
||||
// HookContext contains all the data available to a hook
|
||||
type HookContext struct {
|
||||
Context context.Context
|
||||
Handler *Handler
|
||||
Schema string
|
||||
Entity string
|
||||
Model interface{}
|
||||
Options common.RequestOptions
|
||||
Operation string
|
||||
ID string
|
||||
Data interface{}
|
||||
Result interface{}
|
||||
Error error
|
||||
Query common.SelectQuery
|
||||
Abort bool
|
||||
AbortMessage string
|
||||
AbortCode int
|
||||
Tx common.Database
|
||||
}
|
||||
|
||||
// HookFunc is the signature for hook functions
|
||||
type HookFunc func(*HookContext) error
|
||||
|
||||
// HookRegistry manages all registered hooks
|
||||
type HookRegistry struct {
|
||||
hooks map[HookType][]HookFunc
|
||||
}
|
||||
|
||||
func NewHookRegistry() *HookRegistry {
|
||||
return &HookRegistry{
|
||||
hooks: make(map[HookType][]HookFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||
if r.hooks == nil {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||
}
|
||||
|
||||
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||
for _, hookType := range hookTypes {
|
||||
r.Register(hookType, hook)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
if !exists || len(hooks) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||
|
||||
for i, hook := range hooks {
|
||||
if err := hook(ctx); err != nil {
|
||||
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
if ctx.Abort {
|
||||
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *HookRegistry) Clear(hookType HookType) {
|
||||
delete(r.hooks, hookType)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) ClearAll() {
|
||||
r.hooks = make(map[HookType][]HookFunc)
|
||||
}
|
||||
|
||||
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||
hooks, exists := r.hooks[hookType]
|
||||
return exists && len(hooks) > 0
|
||||
}
|
||||
264
pkg/resolvemcp/oauth2.go
Normal file
264
pkg/resolvemcp/oauth2.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 registration on the Handler
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// oauth2Registration stores a configured auth provider and its route config.
|
||||
type oauth2Registration struct {
|
||||
auth *security.DatabaseAuthenticator
|
||||
cfg OAuth2RouteConfig
|
||||
}
|
||||
|
||||
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
|
||||
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
|
||||
// Call this once per provider before serving requests.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google",
|
||||
// LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback",
|
||||
// AfterLoginRedirect: "/",
|
||||
// })
|
||||
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
|
||||
}
|
||||
|
||||
// HTTPHandler returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP SSE transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// auth := security.NewGoogleAuthenticator(...)
|
||||
// handler.RegisterOAuth2(auth, cfg)
|
||||
// handler.EnableOAuthServer(resolvemcp.OAuthServerConfig{Issuer: "https://api.example.com"})
|
||||
// security.RegisterSecurityHooks(handler, securityList)
|
||||
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedSSEServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/sse", mcpHandler)
|
||||
mux.Handle(basePath+"/message", mcpHandler)
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// StreamableHTTPMux returns a single http.Handler that serves:
|
||||
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||
// - The MCP streamable HTTP transport wrapped with required authentication middleware
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
|
||||
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
if h.oauthSrv != nil {
|
||||
h.mountOAuthServerRoutes(mux)
|
||||
}
|
||||
h.mountOAuth2Routes(mux)
|
||||
|
||||
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
|
||||
basePath := h.config.BasePath
|
||||
if basePath == "" {
|
||||
basePath = "/mcp"
|
||||
}
|
||||
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||
mux.Handle(basePath, mcpHandler)
|
||||
|
||||
return mux
|
||||
}
|
||||
|
||||
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
|
||||
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
|
||||
for _, reg := range h.oauth2Regs {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if reg.cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
|
||||
}
|
||||
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
|
||||
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Auth-wrapped transports
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
|
||||
// The middleware reads the session cookie / Authorization header and populates the user
|
||||
// context into the request context, making it available to BeforeHandle security hooks.
|
||||
// Unauthenticated requests receive 401 before reaching any MCP tool.
|
||||
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
|
||||
// Unauthenticated requests continue as guest rather than returning 401.
|
||||
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
|
||||
// to allow mixed public/private access.
|
||||
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
|
||||
}
|
||||
|
||||
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
|
||||
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
|
||||
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// OAuth2 route config and standalone handlers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
|
||||
type OAuth2RouteConfig struct {
|
||||
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
|
||||
// (e.g. "google", "github", "microsoft").
|
||||
ProviderName string
|
||||
|
||||
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
|
||||
// (e.g. "/auth/google/login").
|
||||
LoginPath string
|
||||
|
||||
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
|
||||
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
|
||||
CallbackPath string
|
||||
|
||||
// AfterLoginRedirect is the URL to redirect the browser to after a successful
|
||||
// login. When empty the LoginResponse JSON is written directly to the response.
|
||||
AfterLoginRedirect string
|
||||
|
||||
// CookieOptions customises the session cookie written on successful login.
|
||||
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
|
||||
CookieOptions *security.SessionCookieOptions
|
||||
}
|
||||
|
||||
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
|
||||
// the OAuth2 provider's authorization URL.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
|
||||
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := auth.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to generate state", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
|
||||
// callback: exchanges the authorization code for a session token, writes the session
|
||||
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
|
||||
//
|
||||
// Register it on any router:
|
||||
//
|
||||
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
|
||||
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if code == "" {
|
||||
http.Error(w, "missing code parameter", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
security.SetSessionCookie(w, loginResp, cookieOpts...)
|
||||
|
||||
if afterLoginRedirect != "" {
|
||||
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Gorilla Mux convenience helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||
// ProviderName: "google", LoginPath: "/auth/google/login",
|
||||
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
|
||||
// })
|
||||
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||
var cookieOpts []security.SessionCookieOptions
|
||||
if cfg.CookieOptions != nil {
|
||||
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
|
||||
}
|
||||
|
||||
muxRouter.Handle(cfg.LoginPath,
|
||||
OAuth2LoginHandler(auth, cfg.ProviderName),
|
||||
).Methods(http.MethodGet)
|
||||
|
||||
muxRouter.Handle(cfg.CallbackPath,
|
||||
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
|
||||
).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
|
||||
// with required authentication middleware applied.
|
||||
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedSSEServer(securityList)
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
|
||||
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
|
||||
// Gorilla Mux router with required authentication middleware applied.
|
||||
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.AuthedStreamableHTTPServer(securityList)
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
51
pkg/resolvemcp/oauth2_server.go
Normal file
51
pkg/resolvemcp/oauth2_server.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
|
||||
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
|
||||
// only external providers registered via RegisterOAuth2Provider.
|
||||
//
|
||||
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
|
||||
// endpoints required by MCP clients alongside the MCP transport:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
|
||||
// POST /oauth/authorize Login form submission (password flow)
|
||||
// POST /oauth/token Bearer token exchange + refresh
|
||||
// GET /oauth/provider/callback External provider redirect target
|
||||
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
|
||||
h.oauthSrv = security.NewOAuthServer(cfg, auth)
|
||||
// Wire any external providers already registered via RegisterOAuth2
|
||||
for _, reg := range h.oauth2Regs {
|
||||
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
|
||||
// EnableOAuthServer must be called before this. The auth must have been configured with
|
||||
// WithOAuth2(providerName, ...) for the given provider name.
|
||||
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
|
||||
if h.oauthSrv != nil {
|
||||
h.oauthSrv.RegisterExternalProvider(auth, providerName)
|
||||
}
|
||||
}
|
||||
|
||||
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
|
||||
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
|
||||
oauthHandler := h.oauthSrv.HTTPHandler()
|
||||
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
|
||||
mux.Handle("/.well-known/", oauthHandler)
|
||||
mux.Handle("/oauth/", oauthHandler)
|
||||
if h.oauthSrv != nil {
|
||||
// Also mount the external provider callback path if it differs from /oauth/
|
||||
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
|
||||
}
|
||||
}
|
||||
133
pkg/resolvemcp/resolvemcp.go
Normal file
133
pkg/resolvemcp/resolvemcp.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||
// and resources over HTTP/SSE transport.
|
||||
//
|
||||
// It mirrors the resolvespec package patterns:
|
||||
// - Same model registration API
|
||||
// - Same filter, sort, cursor pagination, preload options
|
||||
// - Same lifecycle hook system
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// resolvemcp.SetupMuxRoutes(r, handler)
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/uptrace/bun"
|
||||
bunrouter "github.com/uptrace/bunrouter"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// Config holds configuration for the resolvemcp handler.
|
||||
type Config struct {
|
||||
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||
BaseURL string
|
||||
|
||||
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
|
||||
// If empty, the path is detected from each incoming request automatically.
|
||||
BasePath string
|
||||
}
|
||||
|
||||
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
|
||||
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
|
||||
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
|
||||
}
|
||||
|
||||
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
|
||||
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
|
||||
//
|
||||
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||
// before calling SetupMuxRoutes.
|
||||
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
|
||||
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
|
||||
|
||||
// Convenience: also expose the full SSE server at basePath for clients that
|
||||
// use ServeHTTP directly (e.g. net/http default mux).
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
|
||||
// using the base path from Config.BasePath.
|
||||
//
|
||||
// Two routes are registered:
|
||||
// - GET {basePath}/sse — SSE connection endpoint
|
||||
// - POST {basePath}/message — JSON-RPC message endpoint
|
||||
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.SSEServer()
|
||||
|
||||
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewSSEServer returns an http.Handler that serves MCP over SSE.
|
||||
// If Config.BasePath is set it is used directly; otherwise the base path is
|
||||
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
|
||||
//
|
||||
// h := resolvemcp.NewSSEServer(handler)
|
||||
// http.Handle("/api/mcp/", h)
|
||||
func NewSSEServer(handler *Handler) http.Handler {
|
||||
return handler.SSEServer()
|
||||
}
|
||||
|
||||
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
|
||||
// POST for client→server messages, GET for server→client streaming.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
|
||||
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||
}
|
||||
|
||||
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
|
||||
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
|
||||
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
|
||||
basePath := handler.config.BasePath
|
||||
h := handler.StreamableHTTPServer()
|
||||
router.GET(basePath, bunrouter.HTTPHandler(h))
|
||||
router.POST(basePath, bunrouter.HTTPHandler(h))
|
||||
router.DELETE(basePath, bunrouter.HTTPHandler(h))
|
||||
}
|
||||
|
||||
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||
// Mount it at the desired path; that path becomes the MCP endpoint.
|
||||
//
|
||||
// h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||
// http.Handle("/mcp", h)
|
||||
// engine.Any("/mcp", gin.WrapH(h))
|
||||
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
|
||||
return handler.StreamableHTTPServer()
|
||||
}
|
||||
115
pkg/resolvemcp/security_hooks.go
Normal file
115
pkg/resolvemcp/security_hooks.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
)
|
||||
|
||||
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||
//
|
||||
// The following controls are applied:
|
||||
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||
// stored via RegisterModelWithRules / SetModelRules.
|
||||
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||
// - Audit logging after each read.
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||
})
|
||||
|
||||
// AfterRead (2nd): audit log.
|
||||
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeUpdate: enforce CanUpdate rule.
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
// BeforeDelete: enforce CanDelete rule.
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvemcp handler")
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
type securityContext struct {
|
||||
ctx *HookContext
|
||||
}
|
||||
|
||||
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||
return &securityContext{ctx: ctx}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetContext() context.Context {
|
||||
return s.ctx.Context
|
||||
}
|
||||
|
||||
func (s *securityContext) GetUserID() (int, bool) {
|
||||
return security.GetUserID(s.ctx.Context)
|
||||
}
|
||||
|
||||
func (s *securityContext) GetSchema() string {
|
||||
return s.ctx.Schema
|
||||
}
|
||||
|
||||
func (s *securityContext) GetEntity() string {
|
||||
return s.ctx.Entity
|
||||
}
|
||||
|
||||
func (s *securityContext) GetModel() interface{} {
|
||||
return s.ctx.Model
|
||||
}
|
||||
|
||||
func (s *securityContext) GetQuery() interface{} {
|
||||
return s.ctx.Query
|
||||
}
|
||||
|
||||
func (s *securityContext) SetQuery(query interface{}) {
|
||||
if q, ok := query.(common.SelectQuery); ok {
|
||||
s.ctx.Query = q
|
||||
}
|
||||
}
|
||||
|
||||
func (s *securityContext) GetResult() interface{} {
|
||||
return s.ctx.Result
|
||||
}
|
||||
|
||||
func (s *securityContext) SetResult(result interface{}) {
|
||||
s.ctx.Result = result
|
||||
}
|
||||
692
pkg/resolvemcp/tools.go
Normal file
692
pkg/resolvemcp/tools.go
Normal file
@@ -0,0 +1,692 @@
|
||||
package resolvemcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||
)
|
||||
|
||||
// toolName builds the MCP tool name for a given operation and model.
|
||||
func toolName(operation, schema, entity string) string {
|
||||
if schema == "" {
|
||||
return fmt.Sprintf("%s_%s", operation, entity)
|
||||
}
|
||||
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||
}
|
||||
|
||||
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||
info := buildModelInfo(schema, entity, model)
|
||||
registerReadTool(h, schema, entity, info)
|
||||
registerCreateTool(h, schema, entity, info)
|
||||
registerUpdateTool(h, schema, entity, info)
|
||||
registerDeleteTool(h, schema, entity, info)
|
||||
registerModelResource(h, schema, entity, info)
|
||||
|
||||
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Model introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||
type modelInfo struct {
|
||||
fullName string // e.g. "public.users"
|
||||
pkName string // e.g. "id"
|
||||
columns []columnInfo
|
||||
relationNames []string
|
||||
schemaDoc string // formatted multi-line schema listing
|
||||
}
|
||||
|
||||
type columnInfo struct {
|
||||
jsonName string
|
||||
sqlName string
|
||||
goType string
|
||||
sqlType string
|
||||
isPrimary bool
|
||||
isUnique bool
|
||||
isFK bool
|
||||
nullable bool
|
||||
}
|
||||
|
||||
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||
info := modelInfo{
|
||||
fullName: buildModelName(schema, entity),
|
||||
pkName: reflection.GetPrimaryKeyName(model),
|
||||
}
|
||||
|
||||
// Unwrap to base struct type
|
||||
modelType := reflect.TypeOf(model)
|
||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return info
|
||||
}
|
||||
|
||||
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||
|
||||
for _, d := range details {
|
||||
// Derive the JSON name from the struct field
|
||||
jsonName := fieldJSONName(modelType, d.Name)
|
||||
if jsonName == "" || jsonName == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||
fieldType, found := modelType.FieldByName(d.Name)
|
||||
if found {
|
||||
ft := fieldType.Type
|
||||
if ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||
info.relationNames = append(info.relationNames, jsonName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
sqlName := d.SQLName
|
||||
if sqlName == "" {
|
||||
sqlName = jsonName
|
||||
}
|
||||
|
||||
// Derive Go type name, unwrapping pointer if needed.
|
||||
goType := d.DataType
|
||||
if goType == "" && found {
|
||||
ft := fieldType.Type
|
||||
for ft.Kind() == reflect.Ptr {
|
||||
ft = ft.Elem()
|
||||
}
|
||||
goType = ft.Name()
|
||||
}
|
||||
|
||||
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||
isPrimary := d.SQLKey == "primary_key" ||
|
||||
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||
|
||||
ci := columnInfo{
|
||||
jsonName: jsonName,
|
||||
sqlName: sqlName,
|
||||
goType: goType,
|
||||
sqlType: d.SQLDataType,
|
||||
isPrimary: isPrimary,
|
||||
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||
isFK: d.SQLKey == "foreign_key",
|
||||
nullable: d.Nullable,
|
||||
}
|
||||
info.columns = append(info.columns, ci)
|
||||
}
|
||||
|
||||
info.schemaDoc = buildSchemaDoc(info)
|
||||
return info
|
||||
}
|
||||
|
||||
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||
field, ok := modelType.FieldByName(fieldName)
|
||||
if !ok {
|
||||
return fieldName
|
||||
}
|
||||
tag := field.Tag.Get("json")
|
||||
if tag == "" {
|
||||
return fieldName
|
||||
}
|
||||
parts := strings.SplitN(tag, ",", 2)
|
||||
if parts[0] == "" {
|
||||
return fieldName
|
||||
}
|
||||
return parts[0]
|
||||
}
|
||||
|
||||
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||
func buildSchemaDoc(info modelInfo) string {
|
||||
if len(info.columns) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Columns:\n")
|
||||
for _, c := range info.columns {
|
||||
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||
|
||||
typeDesc := c.goType
|
||||
if c.sqlType != "" {
|
||||
typeDesc = c.sqlType
|
||||
}
|
||||
if typeDesc != "" {
|
||||
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||
}
|
||||
|
||||
var flags []string
|
||||
if c.isPrimary {
|
||||
flags = append(flags, "primary key")
|
||||
}
|
||||
if c.isUnique {
|
||||
flags = append(flags, "unique")
|
||||
}
|
||||
if c.isFK {
|
||||
flags = append(flags, "foreign key")
|
||||
}
|
||||
if !c.nullable && !c.isPrimary {
|
||||
flags = append(flags, "not null")
|
||||
} else if c.nullable {
|
||||
flags = append(flags, "nullable")
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
line += " — " + strings.Join(flags, ", ")
|
||||
}
|
||||
|
||||
sb.WriteString(line + "\n")
|
||||
}
|
||||
|
||||
if len(info.relationNames) > 0 {
|
||||
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||
func columnNameList(cols []columnInfo) string {
|
||||
names := make([]string, len(cols))
|
||||
for i, c := range cols {
|
||||
names[i] = c.jsonName
|
||||
}
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||
func writableColumnNames(cols []columnInfo) []string {
|
||||
var names []string
|
||||
for _, c := range cols {
|
||||
if !c.isPrimary {
|
||||
names = append(names, c.jsonName)
|
||||
}
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Read tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("read", schema, entity)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||
)
|
||||
if len(info.relationNames) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
|
||||
if len(info.columns) > 0 {
|
||||
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||
if len(info.columns) > 0 {
|
||||
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||
),
|
||||
mcp.WithNumber("limit",
|
||||
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||
),
|
||||
mcp.WithNumber("offset",
|
||||
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||
),
|
||||
mcp.WithString("cursor_forward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithString("cursor_backward",
|
||||
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||
),
|
||||
mcp.WithArray("columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("omit_columns",
|
||||
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||
),
|
||||
mcp.WithArray("filters",
|
||||
mcp.Description(filterDesc),
|
||||
),
|
||||
mcp.WithArray("sort",
|
||||
mcp.Description(sortDesc),
|
||||
),
|
||||
mcp.WithArray("preloads",
|
||||
mcp.Description(buildPreloadDesc(info)),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
options := parseRequestOptions(args)
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func buildPreloadDesc(info modelInfo) string {
|
||||
if len(info.relationNames) == 0 {
|
||||
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||
strings.Join(info.relationNames, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Create tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("create", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
dataDesc := "Record fields to create."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
dataDesc += " Pass a single object or an array of objects."
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Update tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("update", schema, entity)
|
||||
|
||||
writable := writableColumnNames(info.columns)
|
||||
|
||||
var descParts []string
|
||||
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||
}
|
||||
if len(writable) > 0 {
|
||||
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||
}
|
||||
descParts = append(descParts,
|
||||
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||
)
|
||||
if info.schemaDoc != "" {
|
||||
descParts = append(descParts, info.schemaDoc)
|
||||
}
|
||||
|
||||
description := strings.Join(descParts, "\n\n")
|
||||
|
||||
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||
|
||||
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||
if len(writable) > 0 {
|
||||
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||
}
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(idDesc),
|
||||
),
|
||||
mcp.WithObject("data",
|
||||
mcp.Description(dataDesc),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
data, ok := args["data"]
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||
}
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return mcp.NewToolResultError("data must be an object"), nil
|
||||
}
|
||||
|
||||
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Delete tool
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||
name := toolName("delete", schema, entity)
|
||||
|
||||
descParts := []string{
|
||||
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||
}
|
||||
if info.pkName != "" {
|
||||
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||
}
|
||||
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||
|
||||
description := strings.Join(descParts, " ")
|
||||
|
||||
tool := mcp.NewTool(name,
|
||||
mcp.WithDescription(description),
|
||||
mcp.WithString("id",
|
||||
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||
mcp.Required(),
|
||||
),
|
||||
)
|
||||
|
||||
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
args := req.GetArguments()
|
||||
id, _ := args["id"].(string)
|
||||
|
||||
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(err.Error()), nil
|
||||
}
|
||||
|
||||
return marshalResult(map[string]interface{}{
|
||||
"success": true,
|
||||
"data": result,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Resource registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||
resourceURI := info.fullName
|
||||
|
||||
var resourceDesc strings.Builder
|
||||
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
|
||||
if info.pkName != "" {
|
||||
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
|
||||
}
|
||||
if info.schemaDoc != "" {
|
||||
resourceDesc.WriteString("\n\n")
|
||||
resourceDesc.WriteString(info.schemaDoc)
|
||||
}
|
||||
|
||||
resource := mcp.NewResource(
|
||||
resourceURI,
|
||||
entity,
|
||||
mcp.WithResourceDescription(resourceDesc.String()),
|
||||
mcp.WithMIMEType("application/json"),
|
||||
)
|
||||
|
||||
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
limit := 100
|
||||
options := common.RequestOptions{Limit: &limit}
|
||||
|
||||
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
"data": data,
|
||||
"metadata": metadata,
|
||||
}
|
||||
jsonBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||
}
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "application/json",
|
||||
Text: string(jsonBytes),
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Argument parsing helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||
options := common.RequestOptions{}
|
||||
|
||||
if v, ok := args["limit"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
limit := int(n)
|
||||
options.Limit = &limit
|
||||
case int:
|
||||
options.Limit = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["offset"]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
offset := int(n)
|
||||
options.Offset = &offset
|
||||
case int:
|
||||
options.Offset = &n
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := args["cursor_forward"].(string); ok {
|
||||
options.CursorForward = v
|
||||
}
|
||||
if v, ok := args["cursor_backward"].(string); ok {
|
||||
options.CursorBackward = v
|
||||
}
|
||||
|
||||
options.Columns = parseStringArray(args["columns"])
|
||||
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||
options.Filters = parseFilters(args["filters"])
|
||||
options.Sort = parseSortOptions(args["sort"])
|
||||
options.Preload = parsePreloadOptions(args["preloads"])
|
||||
|
||||
return options
|
||||
}
|
||||
|
||||
func parseStringArray(raw interface{}) []string {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if s, ok := item.(string); ok {
|
||||
result = append(result, s)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseFilters(raw interface{}) []common.FilterOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.FilterOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var f common.FilterOption
|
||||
if err := json.Unmarshal(b, &f); err != nil {
|
||||
continue
|
||||
}
|
||||
if f.Column == "" || f.Operator == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(f.LogicOperator, "or") {
|
||||
f.LogicOperator = "OR"
|
||||
} else {
|
||||
f.LogicOperator = "AND"
|
||||
}
|
||||
result = append(result, f)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.SortOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var s common.SortOption
|
||||
if err := json.Unmarshal(b, &s); err != nil {
|
||||
continue
|
||||
}
|
||||
if s.Column == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, s)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
result := make([]common.PreloadOption, 0, len(items))
|
||||
for _, item := range items {
|
||||
b, err := json.Marshal(item)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
var p common.PreloadOption
|
||||
if err := json.Unmarshal(b, &p); err != nil {
|
||||
continue
|
||||
}
|
||||
if p.Relation == "" {
|
||||
continue
|
||||
}
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||
}
|
||||
return mcp.NewToolResultText(string(b)), nil
|
||||
}
|
||||
572
pkg/resolvespec/EXAMPLES.md
Normal file
572
pkg/resolvespec/EXAMPLES.md
Normal file
@@ -0,0 +1,572 @@
|
||||
# ResolveSpec Query Features Examples
|
||||
|
||||
This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.
|
||||
|
||||
## OR Logic in Filters (SearchOr)
|
||||
|
||||
### Basic OR Filter Example
|
||||
|
||||
Find all users with status "active" OR "pending":
|
||||
|
||||
```json
|
||||
POST /users
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Combined AND/OR Filters
|
||||
|
||||
Find users with (status="active" OR status="pending") AND age >= 18:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gte",
|
||||
"value": 18
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||
|
||||
**Important Notes:**
|
||||
- By default, filters use AND logic
|
||||
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
|
||||
- This grouping ensures OR conditions don't interfere with AND conditions
|
||||
- You don't need to specify `"logic_operator": "AND"` as it's the default
|
||||
|
||||
### Multiple OR Groups
|
||||
|
||||
You can have multiple separate OR groups:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "priority",
|
||||
"operator": "eq",
|
||||
"value": "high"
|
||||
},
|
||||
{
|
||||
"column": "priority",
|
||||
"operator": "eq",
|
||||
"value": "urgent",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`
|
||||
|
||||
## Custom Operators
|
||||
|
||||
### Simple Custom SQL Condition
|
||||
|
||||
Filter by email domain using custom SQL:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "company_emails",
|
||||
"sql": "email LIKE '%@company.com'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Multiple Custom Operators
|
||||
|
||||
Combine multiple custom SQL conditions:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "recent_active",
|
||||
"sql": "last_login > NOW() - INTERVAL '30 days'"
|
||||
},
|
||||
{
|
||||
"name": "high_score",
|
||||
"sql": "score > 1000"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Complex Custom Operator
|
||||
|
||||
Use complex SQL expressions:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "priority_users",
|
||||
"sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Combining Custom Operators with Regular Filters
|
||||
|
||||
Mix custom operators with standard filters:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "country",
|
||||
"operator": "eq",
|
||||
"value": "USA"
|
||||
}
|
||||
],
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "active_last_month",
|
||||
"sql": "last_activity > NOW() - INTERVAL '1 month'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Row Numbers
|
||||
|
||||
### Two Ways to Get Row Numbers
|
||||
|
||||
There are two different features for row numbers:
|
||||
|
||||
1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
|
||||
2. **`RowNumber` field in models** - Automatically number all records in the response
|
||||
|
||||
### 1. FetchRowNumber - Get Position of Specific Record
|
||||
|
||||
Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - Contains ONLY the specified user:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "Alice Smith",
|
||||
"score": 9850,
|
||||
"level": 42
|
||||
},
|
||||
"metadata": {
|
||||
"total": 10000,
|
||||
"count": 1,
|
||||
"filtered": 10000,
|
||||
"row_number": 42
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.
|
||||
|
||||
### Row Number with Filters
|
||||
|
||||
Find position within a filtered subset (e.g., "What's my rank in my country?"):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "country",
|
||||
"operator": "eq",
|
||||
"value": "USA"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "Bob Johnson",
|
||||
"country": "USA",
|
||||
"score": 7200,
|
||||
"status": "active"
|
||||
},
|
||||
"metadata": {
|
||||
"total": 2500,
|
||||
"count": 1,
|
||||
"filtered": 2500,
|
||||
"row_number": 156
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.
|
||||
|
||||
### 2. RowNumber Field - Auto-Number All Records
|
||||
|
||||
If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.
|
||||
|
||||
**Model Definition:**
|
||||
```go
|
||||
type Player struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Score int64 `json:"score"`
|
||||
RowNumber int64 `json:"row_number"` // Will be auto-populated
|
||||
}
|
||||
```
|
||||
|
||||
**Request (with pagination):**
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"sort": [{"column": "score", "direction": "desc"}],
|
||||
"limit": 10,
|
||||
"offset": 20
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - RowNumber automatically set:**
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": [
|
||||
{
|
||||
"id": 456,
|
||||
"name": "Player21",
|
||||
"score": 8900,
|
||||
"row_number": 21
|
||||
},
|
||||
{
|
||||
"id": 789,
|
||||
"name": "Player22",
|
||||
"score": 8850,
|
||||
"row_number": 22
|
||||
},
|
||||
{
|
||||
"id": 123,
|
||||
"name": "Player23",
|
||||
"score": 8800,
|
||||
"row_number": 23
|
||||
}
|
||||
// ... records 24-30 ...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**How It Works:**
|
||||
- `row_number = offset + index + 1` (1-based)
|
||||
- With offset=20, first record gets row_number=21
|
||||
- With offset=20, second record gets row_number=22
|
||||
- Perfect for displaying "Rank" in paginated tables
|
||||
|
||||
**Use Case:** Displaying leaderboards with rank numbers:
|
||||
```
|
||||
Rank | Player | Score
|
||||
-----|-----------|-------
|
||||
21 | Player21 | 8900
|
||||
22 | Player22 | 8850
|
||||
23 | Player23 | 8800
|
||||
```
|
||||
|
||||
**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.
|
||||
|
||||
### When to Use Each Feature
|
||||
|
||||
| Feature | Use Case | Returns | Performance |
|
||||
|---------|----------|---------|-------------|
|
||||
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
|
||||
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
|
||||
|
||||
**Combined Example - Full Leaderboard UI:**
|
||||
|
||||
```javascript
|
||||
// Request 1: Get current user's rank
|
||||
const userRank = await api.read({
|
||||
fetch_row_number: currentUserId,
|
||||
sort: [{column: "score", direction: "desc"}]
|
||||
});
|
||||
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
|
||||
|
||||
// Request 2: Get top 10 with rank numbers
|
||||
const top10 = await api.read({
|
||||
sort: [{column: "score", direction: "desc"}],
|
||||
limit: 10,
|
||||
offset: 0
|
||||
});
|
||||
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
|
||||
|
||||
// Display:
|
||||
// "Your Rank: #156"
|
||||
// "Top Players:"
|
||||
// "#1 - Alice - 9999"
|
||||
// "#2 - Bob - 9876"
|
||||
// ...
|
||||
```
|
||||
|
||||
## Complete Example: Advanced Query
|
||||
|
||||
Combine all features for a complex query:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email", "score", "status"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "trial",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "score",
|
||||
"operator": "gte",
|
||||
"value": 100
|
||||
}
|
||||
],
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "recent_activity",
|
||||
"sql": "last_login > NOW() - INTERVAL '7 days'"
|
||||
},
|
||||
{
|
||||
"name": "verified_email",
|
||||
"sql": "email_verified = true"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
},
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "asc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345",
|
||||
"limit": 50,
|
||||
"offset": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This query:
|
||||
- Selects specific columns
|
||||
- Filters for users with status "active" OR "trial"
|
||||
- AND score >= 100
|
||||
- Applies custom SQL conditions for recent activity and verified emails
|
||||
- Sorts by score (descending) then creation date (ascending)
|
||||
- Returns the row number of user "12345" in this filtered/sorted set
|
||||
- Returns 50 records starting from the first one
|
||||
|
||||
## Use Cases
|
||||
|
||||
### 1. Leaderboards - Get Current User's Rank
|
||||
|
||||
Get the current user's position and data (returns only their record):
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "game_id",
|
||||
"operator": "eq",
|
||||
"value": "game123"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "current_user_id"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Tip:** For full leaderboards, make two requests:
|
||||
1. One with `fetch_row_number` to get user's rank
|
||||
2. One with `limit` and `offset` to get top players list
|
||||
|
||||
### 2. Multi-Status Search
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "pending"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "processing",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "order_status",
|
||||
"operator": "eq",
|
||||
"value": "shipped",
|
||||
"logic_operator": "OR"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Advanced Date Filtering
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "this_month",
|
||||
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
|
||||
},
|
||||
{
|
||||
"name": "business_hours",
|
||||
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:
|
||||
|
||||
1. **Never** directly interpolate user input into custom operator SQL
|
||||
2. Always validate and sanitize custom operator SQL on the backend
|
||||
3. Consider using a whitelist of allowed custom operators
|
||||
4. Use prepared statements or parameterized queries when possible
|
||||
5. Implement proper authorization checks before executing queries
|
||||
|
||||
Example of safe custom operator handling in Go:
|
||||
|
||||
```go
|
||||
// Whitelist of allowed custom operators
|
||||
allowedOperators := map[string]string{
|
||||
"recent_week": "created_at > NOW() - INTERVAL '7 days'",
|
||||
"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
|
||||
"premium_only": "subscription_level = 'premium'",
|
||||
}
|
||||
|
||||
// Validate custom operators from request
|
||||
for _, op := range req.Options.CustomOperators {
|
||||
if sql, ok := allowedOperators[op.Name]; ok {
|
||||
op.SQL = sql // Use whitelisted SQL
|
||||
} else {
|
||||
return errors.New("custom operator not allowed: " + op.Name)
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -214,6 +214,146 @@ Content-Type: application/json
|
||||
|
||||
```json
|
||||
{
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
},
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "pending",
|
||||
"logic_operator": "OR"
|
||||
},
|
||||
{
|
||||
"column": "age",
|
||||
"operator": "gte",
|
||||
"value": 18
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||
|
||||
This grouping ensures OR conditions don't interfere with other AND conditions in the query.
|
||||
|
||||
### Custom Operators
|
||||
|
||||
Add custom SQL conditions when needed:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"customOperators": [
|
||||
{
|
||||
"name": "email_domain_filter",
|
||||
"sql": "LOWER(email) LIKE '%@example.com'"
|
||||
},
|
||||
{
|
||||
"name": "recent_records",
|
||||
"sql": "created_at > NOW() - INTERVAL '7 days'"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Custom operators are applied as additional WHERE conditions to your query.
|
||||
|
||||
### Fetch Row Number
|
||||
|
||||
Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "active"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "score",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"fetch_row_number": "12345"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response - Returns ONLY the specified record with its position:**
|
||||
|
||||
```json
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"id": 12345,
|
||||
"name": "John Doe",
|
||||
"score": 850,
|
||||
"status": "active"
|
||||
},
|
||||
"metadata": {
|
||||
"total": 1000,
|
||||
"count": 1,
|
||||
"filtered": 1000,
|
||||
"row_number": 42
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.
|
||||
|
||||
**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
|
||||
|
||||
## Preloading
|
||||
|
||||
Load related entities with custom configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"operation": "read",
|
||||
"options": {
|
||||
"columns": ["id", "name", "email"],
|
||||
"preload": [
|
||||
{
|
||||
"relation": "posts",
|
||||
"columns": ["id", "title", "created_at"],
|
||||
"filters": [
|
||||
{
|
||||
"column": "status",
|
||||
"operator": "eq",
|
||||
"value": "published"
|
||||
}
|
||||
],
|
||||
"sort": [
|
||||
{
|
||||
"column": "created_at",
|
||||
"direction": "desc"
|
||||
}
|
||||
],
|
||||
"limit": 5
|
||||
},
|
||||
{
|
||||
"relation": "profile",
|
||||
"columns": ["bio", "website"]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
Efficient pagination for large datasets:
|
||||
|
||||
### First Request (No Cursor)
|
||||
|
||||
```json
|
||||
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
|
||||
// Check permissions
|
||||
if !userHasPermission(ctx.Context, ctx.Entity) {
|
||||
return fmt.Errorf("unauthorized access to %s", ctx.Entity)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Modify query options
|
||||
if ctx.Options.Limit == nil || *ctx.Options.Limit > 100 {
|
||||
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
|
||||
}
|
||||
|
||||
return nil
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
})
|
||||
|
||||
// Register an after-read hook (e.g., for data transformation)
|
||||
handler.Hooks().Register(resolvespec.AfterRead, func(ctx *resolvespec.HookContext) error {
|
||||
})
|
||||
// Transform or filter results
|
||||
if users, ok := ctx.Result.([]User); ok {
|
||||
for i := range users {
|
||||
users[i].Email = maskEmail(users[i].Email)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Register a before-create hook (e.g., for validation)
|
||||
handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookContext) error {
|
||||
// Validate data
|
||||
if user, ok := ctx.Data.(*User); ok {
|
||||
if user.Email == "" {
|
||||
return fmt.Errorf("email is required")
|
||||
}
|
||||
// Add timestamps
|
||||
@@ -497,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Tags []Tag `json:"tags,omitempty" gorm:"many2many:post_tags"`
|
||||
}
|
||||
|
||||
// Schema.Table format
|
||||
handler.registry.RegisterModel("core.users", &User{})
|
||||
handler.registry.RegisterModel("core.posts", &Post{})
|
||||
@@ -507,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/resolvespec"
|
||||
"github.com/gorilla/mux"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ const (
|
||||
// - pkName: primary key column (e.g. "id")
|
||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||
// - options: the request options containing sort and cursor information
|
||||
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||
//
|
||||
// Returns SQL snippet to embed in WHERE clause.
|
||||
func GetCursorFilter(
|
||||
@@ -31,8 +32,10 @@ func GetCursorFilter(
|
||||
pkName string,
|
||||
modelColumns []string,
|
||||
options common.RequestOptions,
|
||||
expandJoins map[string]string,
|
||||
) (string, error) {
|
||||
// Remove schema prefix if present
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
@@ -57,18 +60,19 @@ func GetCursorFilter(
|
||||
// 3. Prepare
|
||||
// --------------------------------------------------------------------- //
|
||||
var whereClauses []string
|
||||
joinSQL := ""
|
||||
reverse := direction < 0
|
||||
|
||||
// --------------------------------------------------------------------- //
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse: "created_at", "user.name", etc.
|
||||
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||
parts := strings.Split(col, ".")
|
||||
field := strings.TrimSpace(parts[len(parts)-1])
|
||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||
@@ -81,7 +85,7 @@ func GetCursorFilter(
|
||||
}
|
||||
|
||||
// Resolve column
|
||||
cursorCol, targetCol, err := resolveColumn(
|
||||
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||
field, prefix, tableName, modelColumns,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -89,6 +93,22 @@ func GetCursorFilter(
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Build inequality
|
||||
op := "<"
|
||||
if desc {
|
||||
@@ -112,10 +132,12 @@ func GetCursorFilter(
|
||||
query := fmt.Sprintf(`EXISTS (
|
||||
SELECT 1
|
||||
FROM %s cursor_select
|
||||
%s
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
orSQL,
|
||||
@@ -136,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
||||
return "", 0
|
||||
}
|
||||
|
||||
// Helper: resolve column (main table only for now)
|
||||
// Helper: resolve column (main table or join)
|
||||
func resolveColumn(
|
||||
field, prefix, tableName string,
|
||||
modelColumns []string,
|
||||
) (cursorCol, targetCol string, err error) {
|
||||
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||
|
||||
// JSON field
|
||||
if strings.Contains(field, "->") {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Main table column
|
||||
if modelColumns != nil {
|
||||
for _, col := range modelColumns {
|
||||
if strings.EqualFold(col, field) {
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No validation → allow all main-table fields
|
||||
return "cursor_select." + field, tableName + "." + field, nil
|
||||
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||
}
|
||||
|
||||
// Joined column (not supported in resolvespec yet)
|
||||
// Joined column
|
||||
if prefix != "" && prefix != tableName {
|
||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
||||
return "", "", true, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
||||
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||
}
|
||||
|
||||
// Helper: rewrite JOIN clause for cursor subquery
|
||||
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||
cursorAlias = "cursor_select_" + alias
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||
return joinSQL, cursorAlias
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------- //
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "created_at"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no cursor is provided")
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title"}
|
||||
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err == nil {
|
||||
t.Error("Expected error when no sort columns are defined")
|
||||
}
|
||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
@@ -170,19 +170,50 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "name", "email"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
// Should include full schema-qualified name in FROM clause
|
||||
if !strings.Contains(filter, "public.users") {
|
||||
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
options := common.RequestOptions{
|
||||
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||
CursorForward: "8975",
|
||||
}
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetActiveCursor(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Joined column (not supported)",
|
||||
name: "Joined column (isJoin=true, no error)",
|
||||
field: "name",
|
||||
prefix: "user",
|
||||
tableName: "posts",
|
||||
modelColumns: []string{"id", "title"},
|
||||
wantErr: true,
|
||||
wantErr: false,
|
||||
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// For join columns, cursor/target are empty and isJoin=true
|
||||
if isJoin {
|
||||
if cursor != "" || target != "" {
|
||||
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if cursor != tt.wantCursor {
|
||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||
}
|
||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
||||
pkName := "id"
|
||||
modelColumns := []string{"id", "created_at"}
|
||||
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
143
pkg/resolvespec/filter_test.go
Normal file
143
pkg/resolvespec/filter_test.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package resolvespec
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestBuildFilterCondition tests the filter condition builder
|
||||
func TestBuildFilterCondition(t *testing.T) {
|
||||
h := &Handler{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filter common.FilterOption
|
||||
expectedCondition string
|
||||
expectedArgsCount int
|
||||
}{
|
||||
{
|
||||
name: "Equal operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "status",
|
||||
Operator: "eq",
|
||||
Value: "active",
|
||||
},
|
||||
expectedCondition: "status = ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Greater than operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "age",
|
||||
Operator: "gt",
|
||||
Value: 18,
|
||||
},
|
||||
expectedCondition: "age > ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
{
|
||||
name: "IN operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "status",
|
||||
Operator: "in",
|
||||
Value: []string{"active", "pending"},
|
||||
},
|
||||
expectedCondition: "status IN (?,?)",
|
||||
expectedArgsCount: 2,
|
||||
},
|
||||
{
|
||||
name: "LIKE operator",
|
||||
filter: common.FilterOption{
|
||||
Column: "email",
|
||||
Operator: "like",
|
||||
Value: "%@example.com",
|
||||
},
|
||||
expectedCondition: "email LIKE ?",
|
||||
expectedArgsCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
condition, args := h.buildFilterCondition(tt.filter)
|
||||
|
||||
if condition != tt.expectedCondition {
|
||||
t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
|
||||
}
|
||||
|
||||
if len(args) != tt.expectedArgsCount {
|
||||
t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
|
||||
}
|
||||
|
||||
// Note: Skip value comparison for slices as they can't be compared with ==
|
||||
// The important part is that args are populated correctly
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestORGrouping tests that consecutive OR filters are properly grouped
|
||||
func TestORGrouping(t *testing.T) {
|
||||
// This is a conceptual test - in practice we'd need a mock SelectQuery
|
||||
// to verify the actual SQL grouping behavior
|
||||
t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||
{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
|
||||
{Column: "age", Operator: "gte", Value: 18},
|
||||
}
|
||||
|
||||
// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
|
||||
// The first three filters should be grouped together
|
||||
// The fourth filter should be separate with AND
|
||||
|
||||
// Count OR groups
|
||||
orGroupCount := 0
|
||||
inORGroup := false
|
||||
|
||||
for i := 1; i < len(filters); i++ {
|
||||
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||
orGroupCount++
|
||||
inORGroup = true
|
||||
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||
inORGroup = false
|
||||
}
|
||||
}
|
||||
|
||||
// We should have detected one OR group
|
||||
if orGroupCount != 1 {
|
||||
t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
|
||||
filters := []common.FilterOption{
|
||||
{Column: "status", Operator: "eq", Value: "active"},
|
||||
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||
{Column: "priority", Operator: "eq", Value: "high"},
|
||||
{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
|
||||
}
|
||||
|
||||
// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
|
||||
// Should have two OR groups
|
||||
|
||||
orGroupCount := 0
|
||||
inORGroup := false
|
||||
|
||||
for i := 1; i < len(filters); i++ {
|
||||
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||
orGroupCount++
|
||||
inORGroup = true
|
||||
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||
inORGroup = false
|
||||
}
|
||||
}
|
||||
|
||||
// We should have detected two OR groups
|
||||
if orGroupCount != 2 {
|
||||
t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -138,6 +138,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
validator := common.NewColumnValidator(model)
|
||||
req.Options = validator.FilterRequestOptions(req.Options)
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: req.Operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch req.Operation {
|
||||
case "read":
|
||||
h.handleRead(ctx, w, id, req.Options)
|
||||
@@ -280,10 +300,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
for _, filter := range options.Filters {
|
||||
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
|
||||
query = h.applyFilter(query, filter)
|
||||
// Apply filters with proper grouping for OR logic
|
||||
query = h.applyFilters(query, options.Filters)
|
||||
|
||||
// Apply custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
|
||||
query = query.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Apply sorting
|
||||
@@ -306,8 +329,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Get cursor filter SQL
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||
if err != nil {
|
||||
logger.Error("Error building cursor filter: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||
@@ -318,6 +346,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
if cursorFilter != "" {
|
||||
logger.Debug("Applying cursor filter: %s", cursorFilter)
|
||||
sanitizedCursor := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedCursor = common.EnsureOuterParentheses(sanitizedCursor)
|
||||
if sanitizedCursor != "" {
|
||||
query = query.Where(sanitizedCursor)
|
||||
}
|
||||
@@ -379,24 +409,105 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
}
|
||||
|
||||
// Apply pagination
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
logger.Debug("Applying limit: %d", *options.Limit)
|
||||
query = query.Limit(*options.Limit)
|
||||
// Handle FetchRowNumber if requested
|
||||
var rowNumber *int64
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// Build ROW_NUMBER window function SQL
|
||||
rowNumberSQL := "ROW_NUMBER() OVER ("
|
||||
if len(options.Sort) > 0 {
|
||||
rowNumberSQL += "ORDER BY "
|
||||
for i, sort := range options.Sort {
|
||||
if i > 0 {
|
||||
rowNumberSQL += ", "
|
||||
}
|
||||
direction := "ASC"
|
||||
if strings.EqualFold(sort.Direction, "desc") {
|
||||
direction = "DESC"
|
||||
}
|
||||
rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
|
||||
}
|
||||
}
|
||||
rowNumberSQL += ")"
|
||||
|
||||
// Create a query to fetch the row number using a subquery approach
|
||||
// We'll select the PK and row_number, then filter by the target ID
|
||||
type RowNumResult struct {
|
||||
RowNum int64 `bun:"row_num"`
|
||||
}
|
||||
|
||||
rowNumQuery := h.db.NewSelect().Table(tableName).
|
||||
ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
|
||||
Column(pkName)
|
||||
|
||||
// Apply the same filters as the main query
|
||||
for _, filter := range options.Filters {
|
||||
rowNumQuery = h.applyFilter(rowNumQuery, filter)
|
||||
}
|
||||
|
||||
// Apply custom operators
|
||||
for _, customOp := range options.CustomOperators {
|
||||
rowNumQuery = rowNumQuery.Where(customOp.SQL)
|
||||
}
|
||||
|
||||
// Filter for the specific ID we want the row number for
|
||||
rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)
|
||||
|
||||
// Execute query to get row number
|
||||
var result RowNumResult
|
||||
if err := rowNumQuery.Scan(ctx, &result); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
// Build filter description for error message
|
||||
filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
|
||||
if len(options.CustomOperators) > 0 {
|
||||
customOps := make([]string, 0, len(options.CustomOperators))
|
||||
for _, op := range options.CustomOperators {
|
||||
customOps = append(customOps, op.SQL)
|
||||
}
|
||||
filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
|
||||
}
|
||||
logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
|
||||
} else {
|
||||
logger.Warn("Error fetching row number: %v", err)
|
||||
}
|
||||
} else {
|
||||
rowNumber = &result.RowNum
|
||||
logger.Debug("Found row number: %d", *rowNumber)
|
||||
}
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
logger.Debug("Applying offset: %d", *options.Offset)
|
||||
query = query.Offset(*options.Offset)
|
||||
|
||||
// Apply pagination (skip if FetchRowNumber is set - we want only that record)
|
||||
if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
|
||||
if options.Limit != nil && *options.Limit > 0 {
|
||||
logger.Debug("Applying limit: %d", *options.Limit)
|
||||
query = query.Limit(*options.Limit)
|
||||
}
|
||||
if options.Offset != nil && *options.Offset > 0 {
|
||||
logger.Debug("Applying offset: %d", *options.Offset)
|
||||
query = query.Offset(*options.Offset)
|
||||
}
|
||||
}
|
||||
|
||||
// Execute query
|
||||
var result interface{}
|
||||
if id != "" {
|
||||
logger.Debug("Querying single record with ID: %s", id)
|
||||
if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
|
||||
// Single record query - either by URL ID or FetchRowNumber
|
||||
var targetID string
|
||||
if id != "" {
|
||||
targetID = id
|
||||
logger.Debug("Querying single record with URL ID: %s", id)
|
||||
} else {
|
||||
targetID = *options.FetchRowNumber
|
||||
logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
|
||||
}
|
||||
|
||||
// For single record, create a new pointer to the struct type
|
||||
singleResult := reflect.New(modelType).Interface()
|
||||
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := query.Scan(ctx, singleResult); err != nil {
|
||||
logger.Error("Error querying record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||
@@ -416,20 +527,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
|
||||
logger.Info("Successfully retrieved records")
|
||||
|
||||
// Build metadata
|
||||
limit := 0
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
offset := 0
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
count := int64(total)
|
||||
|
||||
// When FetchRowNumber is used, we only return 1 record
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
count = 1
|
||||
// Set the fetched row number on the record
|
||||
if rowNumber != nil {
|
||||
logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
|
||||
h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
}
|
||||
} else {
|
||||
if options.Limit != nil {
|
||||
limit = *options.Limit
|
||||
}
|
||||
if options.Offset != nil {
|
||||
offset = *options.Offset
|
||||
}
|
||||
|
||||
// Set row numbers on records if RowNumber field exists
|
||||
// Only for multiple records (not when fetching single record)
|
||||
h.setRowNumbersOnRecords(result, offset)
|
||||
}
|
||||
|
||||
h.sendResponse(w, result, &common.Metadata{
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
Total: int64(total),
|
||||
Filtered: int64(total),
|
||||
Count: count,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
RowNumber: rowNumber,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -701,97 +831,130 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
// Get the primary key name
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := h.db.NewSelect().Model(existingRecord)
|
||||
// Wrap in transaction to ensure BeforeUpdate hook is inside transaction
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*")
|
||||
|
||||
// Apply conditions to select
|
||||
if urlID != "" {
|
||||
logger.Debug("Updating by URL ID: %s", urlID)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
logger.Debug("Updating by request ID: %s", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
if len(id) > 0 {
|
||||
logger.Debug("Updating by multiple IDs: %v", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
// Apply conditions to select
|
||||
if urlID != "" {
|
||||
logger.Debug("Updating by URL ID: %s", urlID)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
logger.Debug("Updating by request ID: %s", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
if len(id) > 0 {
|
||||
logger.Debug("Updating by multiple IDs: %v", id)
|
||||
selectQuery = selectQuery.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
logger.Warn("No records found to update")
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
|
||||
return
|
||||
}
|
||||
logger.Error("Error fetching existing record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error fetching existing record", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling existing record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err)
|
||||
return
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
logger.Error("Error unmarshaling existing record: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error processing existing record", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||
for key, newValue := range updates {
|
||||
// Skip if the value is nil
|
||||
if newValue == nil {
|
||||
continue
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
return fmt.Errorf("error fetching existing record: %w", err)
|
||||
}
|
||||
|
||||
// Skip if the value is an empty string
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
// Convert existing record to map
|
||||
existingMap := make(map[string]interface{})
|
||||
jsonData, err := json.Marshal(existingRecord)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||
}
|
||||
|
||||
// Update the existing map with the new value
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
// Build update query with merged data
|
||||
query := h.db.NewUpdate().Table(tableName).SetMap(existingMap)
|
||||
|
||||
// Apply conditions
|
||||
if urlID != "" {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: urlID,
|
||||
Data: updates,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
updates = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||
for key, newValue := range updates {
|
||||
// Skip if the value is nil
|
||||
if newValue == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if the value is an empty string
|
||||
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Update the existing map with the new value
|
||||
existingMap[key] = newValue
|
||||
}
|
||||
|
||||
// Build update query with merged data
|
||||
query := tx.NewUpdate().Table(tableName).SetMap(existingMap)
|
||||
|
||||
// Apply conditions
|
||||
if urlID != "" {
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), urlID)
|
||||
} else if reqID != nil {
|
||||
switch id := reqID.(type) {
|
||||
case string:
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||
case []string:
|
||||
query = query.Where(fmt.Sprintf("%s IN (?)", common.QuoteIdent(pkName)), id)
|
||||
}
|
||||
}
|
||||
|
||||
result, err := query.Exec(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating record(s): %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return fmt.Errorf("no records found to update")
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = updates
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Update error: %v", err)
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
||||
if err.Error() == "no records found to update" {
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", err)
|
||||
} else {
|
||||
h.sendError(w, http.StatusInternalServerError, "update_error", "Error updating record(s)", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
logger.Warn("No records found to update")
|
||||
h.sendError(w, http.StatusNotFound, "not_found", "No records found to update", nil)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated %d records", result.RowsAffected())
|
||||
logger.Info("Successfully updated record(s)")
|
||||
// Invalidate cache for this table
|
||||
cacheTags := buildCacheTags(schema, tableName)
|
||||
if err := invalidateCacheForTags(ctx, cacheTags); err != nil {
|
||||
@@ -849,9 +1012,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
for _, item := range updates {
|
||||
if itemID, ok := item["id"]; ok {
|
||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -869,6 +1034,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: itemIDStr,
|
||||
Data: item,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
item = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values
|
||||
for key, newValue := range item {
|
||||
if newValue == nil {
|
||||
@@ -884,6 +1072,13 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = item
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -957,9 +1152,11 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
for _, item := range updates {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
if itemID, ok := itemMap["id"]; ok {
|
||||
itemIDStr := fmt.Sprintf("%v", itemID)
|
||||
|
||||
// First, read the existing record
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), itemID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
continue // Skip if record not found
|
||||
@@ -977,6 +1174,29 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
return fmt.Errorf("failed to unmarshal existing record: %w", err)
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: itemIDStr,
|
||||
Data: itemMap,
|
||||
Writer: w,
|
||||
Tx: tx,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
itemMap = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values
|
||||
for key, newValue := range itemMap {
|
||||
if newValue == nil {
|
||||
@@ -992,6 +1212,14 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, url
|
||||
if _, err := txQuery.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Execute AfterUpdate hooks inside transaction
|
||||
hookCtx.Result = itemMap
|
||||
hookCtx.Error = nil
|
||||
if err := h.hooks.Execute(AfterUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("AfterUpdate hook failed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
list = append(list, item)
|
||||
}
|
||||
}
|
||||
@@ -1033,6 +1261,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
Tx: h.db,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Error("BeforeDelete hook failed: %v", err)
|
||||
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle batch delete from request data
|
||||
if data != nil {
|
||||
switch v := data.(type) {
|
||||
@@ -1203,29 +1449,165 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
h.sendResponse(w, recordToDelete, nil)
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||
// applyFilters applies all filters with proper grouping for OR logic
|
||||
// Groups consecutive OR filters together to ensure proper query precedence
|
||||
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
i := 0
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (current or next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []common.FilterOption{filters[i]}
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
orGroup = append(orGroup, filters[j])
|
||||
j++
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped WHERE clause
|
||||
query = h.applyFilterGroup(query, orGroup)
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
condition, args := h.buildFilterCondition(filters[i])
|
||||
if condition != "" {
|
||||
query = query.Where(condition, args...)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
// applyFilterGroup applies a group of filters that should be OR'd together
|
||||
// Always wraps them in parentheses and applies as a single WHERE clause
|
||||
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Build all conditions and collect args
|
||||
var conditions []string
|
||||
var args []interface{}
|
||||
|
||||
for _, filter := range filters {
|
||||
condition, filterArgs := h.buildFilterCondition(filter)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Single filter - no need for grouping
|
||||
if len(conditions) == 1 {
|
||||
return query.Where(conditions[0], args...)
|
||||
}
|
||||
|
||||
// Multiple conditions - group with parentheses and OR
|
||||
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||
return query.Where(groupedCondition, args...)
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a filter condition and returns it with args
|
||||
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq":
|
||||
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
|
||||
case "neq":
|
||||
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
|
||||
case "gt":
|
||||
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
|
||||
case "gte":
|
||||
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
|
||||
case "lt":
|
||||
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
|
||||
case "lte":
|
||||
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
|
||||
case "eq", "=":
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value)
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return "", nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return condition, args
|
||||
}
|
||||
|
||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||
// Determine which method to use based on LogicOperator
|
||||
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
|
||||
|
||||
var condition string
|
||||
var args []interface{}
|
||||
|
||||
switch filter.Operator {
|
||||
case "eq", "=":
|
||||
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "neq", "!=", "<>":
|
||||
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gt", ">":
|
||||
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "gte", ">=":
|
||||
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lt", "<":
|
||||
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "lte", "<=":
|
||||
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "like":
|
||||
condition = fmt.Sprintf("%s LIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
condition = fmt.Sprintf("%s ILIKE ?", filter.Column)
|
||||
args = []interface{}{filter.Value}
|
||||
case "in":
|
||||
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||
if condition == "" {
|
||||
return query
|
||||
}
|
||||
default:
|
||||
return query
|
||||
}
|
||||
|
||||
// Apply filter with appropriate logic operator
|
||||
if useOrLogic {
|
||||
return query.WhereOr(condition, args...)
|
||||
}
|
||||
return query.Where(condition, args...)
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
@@ -1280,10 +1662,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
|
||||
return schema, entity
|
||||
}
|
||||
|
||||
// getTableName returns the full table name including schema (schema.table)
|
||||
// getTableName returns the full table name including schema.
|
||||
// For most drivers the result is "schema.table". For SQLite, which does not
|
||||
// support schema-qualified names, the schema and table are joined with an
|
||||
// underscore: "schema_table".
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
@@ -1558,6 +1946,8 @@ func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, pre
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: preloads}
|
||||
sanitizedWhere := common.SanitizeWhereClause(preload.Where, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -1601,6 +1991,51 @@ func toSnakeCase(s string) string {
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||
// The row number is calculated as offset + index + 1 (1-based)
|
||||
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||
// Get the reflect value of the records
|
||||
recordsValue := reflect.ValueOf(records)
|
||||
if recordsValue.Kind() == reflect.Ptr {
|
||||
recordsValue = recordsValue.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a slice
|
||||
if recordsValue.Kind() != reflect.Slice {
|
||||
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// Iterate through each record
|
||||
for i := 0; i < recordsValue.Len(); i++ {
|
||||
record := recordsValue.Index(i)
|
||||
|
||||
// Dereference if it's a pointer
|
||||
if record.Kind() == reflect.Ptr {
|
||||
if record.IsNil() {
|
||||
continue
|
||||
}
|
||||
record = record.Elem()
|
||||
}
|
||||
|
||||
// Ensure it's a struct
|
||||
if record.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to find and set the RowNumber field
|
||||
rowNumberField := record.FieldByName("RowNumber")
|
||||
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||
// Check if the field is of type int64
|
||||
if rowNumberField.Kind() == reflect.Int64 {
|
||||
rowNum := int64(offset + i + 1)
|
||||
rowNumberField.SetInt(rowNum)
|
||||
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||
if h.openAPIGenerator == nil {
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -43,6 +47,9 @@ type HookContext struct {
|
||||
Writer common.ResponseWriter
|
||||
Request common.Request
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
|
||||
@@ -50,8 +50,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
@@ -69,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
||||
postEntityHandler = authMiddleware(postEntityHandler)
|
||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
|
||||
getEntityHandler = authMiddleware(getEntityHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -98,7 +99,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -106,7 +108,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -117,7 +119,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -125,7 +128,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -137,13 +140,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
corsConfig.AllowedMethods = allowedMethods
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
// Return metadata in the OPTIONS response body
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -212,9 +216,34 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -222,15 +251,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -251,85 +281,97 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// POST route without ID
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// POST route with ID
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// GET route without ID
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
// GET route with ID
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
@@ -344,8 +386,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with bunrouter
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes with bunrouter without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -366,8 +408,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup ResolveSpec routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
@@ -385,8 +427,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup ResolveSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
|
||||
func ExampleWithGORMAndAuth(db *gorm.DB) {
|
||||
// Create handler using GORM
|
||||
_ = NewHandlerWithGORM(db)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup router with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Register models
|
||||
// handler.RegisterModel("public", "users", &User{})
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
|
||||
func ExampleWithBunAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Setup routes with authentication
|
||||
_ = mux.NewRouter()
|
||||
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||
|
||||
// Start server
|
||||
// http.ListenAndServe(":8080", muxRouter)
|
||||
}
|
||||
|
||||
// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
|
||||
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
|
||||
// Create Bun database adapter
|
||||
dbAdapter := database.NewBunAdapter(bunDB)
|
||||
|
||||
// Create model registry
|
||||
registry := modelregistry.NewModelRegistry()
|
||||
// registry.RegisterModel("public.users", &User{})
|
||||
|
||||
// Create handler with Bun
|
||||
_ = NewHandler(dbAdapter, registry)
|
||||
|
||||
// Create auth middleware
|
||||
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
// secList := security.NewSecurityList(myProvider)
|
||||
// authMiddleware := func(h http.Handler) http.Handler {
|
||||
// return security.NewAuthHandler(secList, h)
|
||||
// }
|
||||
|
||||
// Create bunrouter
|
||||
_ = bunrouter.New()
|
||||
|
||||
// Setup ResolveSpec routes with authentication
|
||||
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)
|
||||
|
||||
// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
|
||||
// http.ListenAndServe(":8080", bunRouter)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package resolvespec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
@@ -10,6 +11,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for resolvespec handler")
|
||||
}
|
||||
|
||||
|
||||
@@ -214,14 +214,46 @@ x-expand: department:id,name,code
|
||||
**Note:** Currently, expand falls back to preload behavior. Full JOIN expansion is planned for future implementation.
|
||||
|
||||
#### `x-custom-sql-join`
|
||||
Raw SQL JOIN statement.
|
||||
Custom SQL JOIN clauses for joining tables in queries.
|
||||
|
||||
**Format:** SQL JOIN clause
|
||||
**Format:** SQL JOIN clause or multiple clauses separated by `|`
|
||||
|
||||
**Single JOIN:**
|
||||
```
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||
```
|
||||
|
||||
⚠️ **Note:** Not yet fully implemented.
|
||||
**Multiple JOINs:**
|
||||
```
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Supports any type of JOIN (INNER, LEFT, RIGHT, FULL, CROSS)
|
||||
- Multiple JOINs can be specified using the pipe `|` separator
|
||||
- JOINs are sanitized for security
|
||||
- Can be specified via headers or query parameters
|
||||
- **Table aliases are automatically extracted and allowed for filtering and sorting**
|
||||
|
||||
**Using Join Aliases in Filters and Sorts:**
|
||||
|
||||
When you specify a custom SQL join with an alias, you can use that alias in your filter and sort parameters:
|
||||
|
||||
```
|
||||
# Join with alias
|
||||
x-custom-sql-join: LEFT JOIN departments d ON d.id = employees.department_id
|
||||
|
||||
# Sort by joined table column
|
||||
x-sort: d.name,employees.id
|
||||
|
||||
# Filter by joined table column
|
||||
x-searchop-eq-d.name: Engineering
|
||||
```
|
||||
|
||||
The system automatically:
|
||||
1. Extracts the alias from the JOIN clause (e.g., `d` from `departments d`)
|
||||
2. Validates that prefixed columns (like `d.name`) refer to valid join aliases
|
||||
3. Allows these prefixed columns in filters and sorts
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
```
|
||||
|
||||
**Available Hook Types**:
|
||||
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||
* `BeforeRead`, `AfterRead`
|
||||
* `BeforeCreate`, `AfterCreate`
|
||||
* `BeforeUpdate`, `AfterUpdate`
|
||||
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
||||
* `Handler`: Access to handler, database, and registry
|
||||
* `Schema`, `Entity`, `TableName`: Request info
|
||||
* `Model`: The registered model type
|
||||
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||
* `ID`: Record ID (for single-record operations)
|
||||
* `Data`: Request data (for create/update)
|
||||
* `Result`: Operation result (for after hooks)
|
||||
* `Writer`: Response writer (allows hooks to modify response)
|
||||
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||
|
||||
## Cursor Pagination
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ type queryCacheKey struct {
|
||||
Sort []common.SortOption `json:"sort"`
|
||||
CustomSQLWhere string `json:"custom_sql_where,omitempty"`
|
||||
CustomSQLOr string `json:"custom_sql_or,omitempty"`
|
||||
CustomSQLJoin []string `json:"custom_sql_join,omitempty"`
|
||||
Expand []expandOptionKey `json:"expand,omitempty"`
|
||||
Distinct bool `json:"distinct,omitempty"`
|
||||
CursorForward string `json:"cursor_forward,omitempty"`
|
||||
@@ -40,7 +41,7 @@ type cachedTotal struct {
|
||||
// buildExtendedQueryCacheKey builds a cache key for extended query options (restheadspec)
|
||||
// Includes expand, distinct, and cursor pagination options
|
||||
func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption, sort []common.SortOption,
|
||||
customWhere, customOr string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
customWhere, customOr string, customJoin []string, expandOpts []interface{}, distinct bool, cursorFwd, cursorBwd string) string {
|
||||
|
||||
key := queryCacheKey{
|
||||
TableName: tableName,
|
||||
@@ -48,6 +49,7 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
Sort: sort,
|
||||
CustomSQLWhere: customWhere,
|
||||
CustomSQLOr: customOr,
|
||||
CustomSQLJoin: customJoin,
|
||||
Distinct: distinct,
|
||||
CursorForward: cursorFwd,
|
||||
CursorBackward: cursorBwd,
|
||||
@@ -75,8 +77,8 @@ func buildExtendedQueryCacheKey(tableName string, filters []common.FilterOption,
|
||||
jsonData, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
// Fallback to simple string concatenation if JSON fails
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||
return hashString(fmt.Sprintf("%s_%v_%v_%s_%s_%v_%v_%v_%s_%s",
|
||||
tableName, filters, sort, customWhere, customOr, customJoin, expandOpts, distinct, cursorFwd, cursorBwd))
|
||||
}
|
||||
|
||||
return hashString(string(jsonData))
|
||||
|
||||
@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
modelColumns []string, // optional: for validation
|
||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||
) (string, error) {
|
||||
// Separate schema prefix from bare table name
|
||||
fullTableName := tableName
|
||||
if strings.Contains(tableName, ".") {
|
||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||
}
|
||||
@@ -62,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
// 4. Process each sort column
|
||||
// --------------------------------------------------------------------- //
|
||||
for _, s := range sortItems {
|
||||
col := strings.TrimSpace(s.Column)
|
||||
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||
if col == "" {
|
||||
continue
|
||||
}
|
||||
@@ -91,12 +93,18 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
}
|
||||
|
||||
// Handle joins
|
||||
if isJoin && expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
if isJoin {
|
||||
if expandJoins != nil {
|
||||
if joinClause, ok := expandJoins[prefix]; ok {
|
||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||
joinSQL = jSQL
|
||||
cursorCol = cRef + "." + field
|
||||
targetCol = prefix + "." + field
|
||||
}
|
||||
}
|
||||
if cursorCol == "" {
|
||||
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -127,7 +135,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
||||
WHERE cursor_select.%s = %s
|
||||
AND (%s)
|
||||
)`,
|
||||
tableName,
|
||||
fullTableName,
|
||||
joinSQL,
|
||||
pkName,
|
||||
cursorID,
|
||||
|
||||
@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
// Should handle schema prefix properly
|
||||
if !strings.Contains(filter, "users") {
|
||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
||||
// Should include full schema-qualified name in FROM clause
|
||||
if !strings.Contains(filter, "public.users") {
|
||||
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||
}
|
||||
|
||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||
|
||||
opts := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Sort: []common.SortOption{
|
||||
{Column: "fn.sortorder", Direction: "ASC"},
|
||||
},
|
||||
},
|
||||
}
|
||||
opts.CursorForward = "8975"
|
||||
|
||||
tableName := "core.account"
|
||||
pkName := "rid_account"
|
||||
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||
expandJoins := map[string]string{"fn": lateralJoin}
|
||||
|
||||
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||
if err != nil {
|
||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||
|
||||
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||
if !strings.Contains(filter, "cursor_select_fn") {
|
||||
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should compare fn.sortorder values
|
||||
if !strings.Contains(filter, "sortorder") {
|
||||
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||
}
|
||||
|
||||
// Should NOT contain empty comparison like "< "
|
||||
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPriorityChain(t *testing.T) {
|
||||
clauses := []string{
|
||||
"cursor_select.priority > posts.priority",
|
||||
|
||||
@@ -133,6 +133,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
||||
// Add request-scoped data to context (including options)
|
||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||
|
||||
// Derive operation for auth check
|
||||
var operation string
|
||||
switch method {
|
||||
case "GET":
|
||||
operation = "read"
|
||||
case "POST":
|
||||
operation = "create"
|
||||
case "PUT", "PATCH":
|
||||
operation = "update"
|
||||
case "DELETE":
|
||||
operation = "delete"
|
||||
default:
|
||||
operation = "read"
|
||||
}
|
||||
|
||||
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||
beforeCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
Model: model,
|
||||
Writer: w,
|
||||
Request: r,
|
||||
Operation: operation,
|
||||
}
|
||||
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||
code := http.StatusUnauthorized
|
||||
if beforeCtx.AbortCode != 0 {
|
||||
code = beforeCtx.AbortCode
|
||||
}
|
||||
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch method {
|
||||
case "GET":
|
||||
if id != "" {
|
||||
@@ -435,9 +470,11 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Apply preloading
|
||||
logger.Debug("Total preloads to apply: %d", len(options.Preload))
|
||||
for idx := range options.Preload {
|
||||
preload := options.Preload[idx]
|
||||
logger.Debug("Applying preload: %s", preload.Relation)
|
||||
logger.Debug("Applying preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, Where=%s",
|
||||
idx, preload.Relation, preload.Recursive, preload.RelatedKey, preload.Where)
|
||||
|
||||
// Validate and fix WHERE clause to ensure it contains the relation prefix
|
||||
if len(preload.Where) > 0 {
|
||||
@@ -463,7 +500,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Apply filters - validate and adjust for column types first
|
||||
for i := range options.Filters {
|
||||
// Group consecutive OR filters together to prevent OR logic from escaping
|
||||
for i := 0; i < len(options.Filters); {
|
||||
filter := &options.Filters[i]
|
||||
|
||||
// Validate and adjust filter based on column type
|
||||
@@ -475,8 +513,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
logicOp = "AND"
|
||||
}
|
||||
|
||||
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||
// Check if this is the start of an OR group
|
||||
if logicOp == "OR" {
|
||||
// Collect all consecutive OR filters
|
||||
orFilters := []*common.FilterOption{filter}
|
||||
orCastInfo := []ColumnCastInfo{castInfo}
|
||||
|
||||
j := i + 1
|
||||
for j < len(options.Filters) {
|
||||
nextFilter := &options.Filters[j]
|
||||
nextLogicOp := nextFilter.LogicOperator
|
||||
if nextLogicOp == "" {
|
||||
nextLogicOp = "AND"
|
||||
}
|
||||
if nextLogicOp == "OR" {
|
||||
nextCastInfo := h.ValidateAndAdjustFilterForColumnType(nextFilter, model)
|
||||
orFilters = append(orFilters, nextFilter)
|
||||
orCastInfo = append(orCastInfo, nextCastInfo)
|
||||
j++
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the OR group as a single grouped condition
|
||||
logger.Debug("Applying OR filter group with %d conditions", len(orFilters))
|
||||
query = h.applyOrFilterGroup(query, orFilters, orCastInfo, tableName)
|
||||
i = j
|
||||
} else {
|
||||
// Single AND filter - apply normally
|
||||
logger.Debug("Applying filter: %s %s %v (needsCast=%v, logic=%s)", filter.Column, filter.Operator, filter.Value, castInfo.NeedsCast, logicOp)
|
||||
query = h.applyFilter(query, *filter, tableName, castInfo.NeedsCast, logicOp)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL WHERE clause (AND condition)
|
||||
@@ -486,6 +555,8 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
prefixedWhere := common.AddTablePrefixToColumns(options.CustomSQLWhere, reflection.ExtractTableNameOnly(tableName))
|
||||
// Then sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedWhere = common.EnsureOuterParentheses(sanitizedWhere)
|
||||
if sanitizedWhere != "" {
|
||||
query = query.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -497,13 +568,46 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
customOr := common.AddTablePrefixToColumns(options.CustomSQLOr, reflection.ExtractTableNameOnly(tableName))
|
||||
// Sanitize and allow preload table prefixes since custom SQL may reference multiple tables
|
||||
sanitizedOr := common.SanitizeWhereClause(customOr, reflection.ExtractTableNameOnly(tableName), &options.RequestOptions)
|
||||
// Ensure outer parentheses to prevent OR logic from escaping
|
||||
sanitizedOr = common.EnsureOuterParentheses(sanitizedOr)
|
||||
if sanitizedOr != "" {
|
||||
query = query.WhereOr(sanitizedOr)
|
||||
}
|
||||
}
|
||||
|
||||
// If ID is provided, filter by ID
|
||||
if id != "" {
|
||||
// Apply custom SQL JOIN clauses
|
||||
if len(options.CustomSQLJoin) > 0 {
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
|
||||
// Joins are already sanitized during parsing, so we can apply them directly
|
||||
query = query.Join(joinClause)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle FetchRowNumber before applying ID filter
|
||||
// This must happen before the query to get the row position, then filter by PK
|
||||
var fetchedRowNumber *int64
|
||||
var fetchRowNumberPKValue string
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
fetchRowNumberPKValue = *options.FetchRowNumber
|
||||
|
||||
logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
|
||||
if err != nil {
|
||||
logger.Error("Failed to fetch row number: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
|
||||
return
|
||||
}
|
||||
|
||||
fetchedRowNumber = &rowNum
|
||||
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
|
||||
|
||||
// Now filter the main query to this specific primary key
|
||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), fetchRowNumberPKValue)
|
||||
} else if id != "" {
|
||||
// If ID is provided (and not FetchRowNumber), filter by ID
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
logger.Debug("Filtering by ID=%s: %s", pkName, id)
|
||||
|
||||
@@ -552,6 +656,7 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
options.Sort,
|
||||
options.CustomSQLWhere,
|
||||
options.CustomSQLOr,
|
||||
options.CustomSQLJoin,
|
||||
expandOpts,
|
||||
options.Distinct,
|
||||
options.CursorForward,
|
||||
@@ -618,12 +723,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
// Extract model columns for validation using the generic database function
|
||||
modelColumns := reflection.GetModelColumns(model)
|
||||
|
||||
// Build expand joins map (if needed in future)
|
||||
var expandJoins map[string]string
|
||||
if len(options.Expand) > 0 {
|
||||
expandJoins = make(map[string]string)
|
||||
// TODO: Build actual JOIN SQL for each expand relation
|
||||
// For now, pass empty map as joins are handled via Preload
|
||||
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||
expandJoins := make(map[string]string)
|
||||
for _, joinClause := range options.CustomSQLJoin {
|
||||
alias := extractJoinAlias(joinClause)
|
||||
if alias != "" {
|
||||
expandJoins[alias] = joinClause
|
||||
}
|
||||
}
|
||||
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||
|
||||
// Default sort to primary key when none provided
|
||||
if len(options.Sort) == 0 {
|
||||
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||
}
|
||||
|
||||
// Get cursor filter SQL
|
||||
@@ -682,7 +794,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
}
|
||||
|
||||
// Set row numbers on each record if the model has a RowNumber field
|
||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||
// If FetchRowNumber was used, set the fetched row number instead of offset-based
|
||||
if fetchedRowNumber != nil {
|
||||
// FetchRowNumber: set the actual row position on the record
|
||||
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
|
||||
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||
} else {
|
||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||
}
|
||||
|
||||
metadata := &common.Metadata{
|
||||
Total: int64(total),
|
||||
@@ -692,21 +811,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
// Fetch row number for a specific record if requested
|
||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||
pkName := reflection.GetPrimaryKeyName(model)
|
||||
pkValue := *options.FetchRowNumber
|
||||
|
||||
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
|
||||
|
||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to fetch row number: %v", err)
|
||||
// Don't fail the entire request, just log the warning
|
||||
} else {
|
||||
metadata.RowNumber = &rowNum
|
||||
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
|
||||
}
|
||||
// If FetchRowNumber was used, also set it in metadata
|
||||
if fetchedRowNumber != nil {
|
||||
metadata.RowNumber = fetchedRowNumber
|
||||
logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
|
||||
}
|
||||
|
||||
// Execute AfterRead hooks
|
||||
@@ -836,6 +944,15 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
}
|
||||
}
|
||||
|
||||
// Apply custom SQL joins from XFiles
|
||||
if len(preload.SqlJoins) > 0 {
|
||||
logger.Debug("Applying %d SQL joins to preload %s", len(preload.SqlJoins), preload.Relation)
|
||||
for _, joinClause := range preload.SqlJoins {
|
||||
sq = sq.Join(joinClause)
|
||||
logger.Debug("Applied SQL join to preload %s: %s", preload.Relation, joinClause)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply filters
|
||||
if len(preload.Filters) > 0 {
|
||||
for _, filter := range preload.Filters {
|
||||
@@ -861,10 +978,25 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
if len(preload.Where) > 0 {
|
||||
// Build RequestOptions with all preloads to allow references to sibling relations
|
||||
preloadOpts := &common.RequestOptions{Preload: allPreloads}
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedWhere := common.AddTablePrefixToColumns(preload.Where, reflection.ExtractTableNameOnly(preload.Relation))
|
||||
// Then sanitize and allow preload table prefixes
|
||||
sanitizedWhere := common.SanitizeWhereClause(prefixedWhere, reflection.ExtractTableNameOnly(preload.Relation), preloadOpts)
|
||||
|
||||
// Determine the table name to use for WHERE clause processing
|
||||
// Prefer the explicit TableName field (set by XFiles), otherwise extract from relation name
|
||||
tableName := preload.TableName
|
||||
if tableName == "" {
|
||||
tableName = reflection.ExtractTableNameOnly(preload.Relation)
|
||||
}
|
||||
|
||||
// In Bun's Relation context, table prefixes are only needed when there are JOINs
|
||||
// Without JOINs, Bun already knows which table is being queried
|
||||
whereClause := preload.Where
|
||||
if len(preload.SqlJoins) > 0 {
|
||||
// Has JOINs: add table prefixes to disambiguate columns
|
||||
whereClause = common.AddTablePrefixToColumns(preload.Where, tableName)
|
||||
logger.Debug("Added table prefix for preload with joins: '%s' -> '%s'", preload.Where, whereClause)
|
||||
}
|
||||
|
||||
// Sanitize the WHERE clause and allow preload table prefixes
|
||||
sanitizedWhere := common.SanitizeWhereClause(whereClause, tableName, preloadOpts)
|
||||
if len(sanitizedWhere) > 0 {
|
||||
sq = sq.Where(sanitizedWhere)
|
||||
}
|
||||
@@ -883,21 +1015,82 @@ func (h *Handler) applyPreloadWithRecursion(query common.SelectQuery, preload co
|
||||
})
|
||||
|
||||
// Handle recursive preloading
|
||||
if preload.Recursive && depth < 5 {
|
||||
if preload.Recursive && depth < 8 {
|
||||
logger.Debug("Applying recursive preload for %s at depth %d", preload.Relation, depth+1)
|
||||
|
||||
// For recursive relationships, we need to get the last part of the relation path
|
||||
// e.g., "MastertaskItems" -> "MastertaskItems.MastertaskItems"
|
||||
relationParts := strings.Split(preload.Relation, ".")
|
||||
lastRelationName := relationParts[len(relationParts)-1]
|
||||
|
||||
// Create a recursive preload with the same configuration
|
||||
// but with the relation path extended
|
||||
recursivePreload := preload
|
||||
recursivePreload.Relation = preload.Relation + "." + lastRelationName
|
||||
// Generate FK-based relation name for children
|
||||
// Use RecursiveChildKey if available, otherwise fall back to RelatedKey
|
||||
recursiveFK := preload.RecursiveChildKey
|
||||
if recursiveFK == "" {
|
||||
recursiveFK = preload.RelatedKey
|
||||
}
|
||||
|
||||
// Recursively apply preload until we reach depth 5
|
||||
recursiveRelationName := lastRelationName
|
||||
if recursiveFK != "" {
|
||||
// Check if the last relation name already contains the FK suffix
|
||||
// (this happens when XFiles already generated the FK-based name)
|
||||
fkUpper := strings.ToUpper(recursiveFK)
|
||||
expectedSuffix := "_" + fkUpper
|
||||
|
||||
if strings.HasSuffix(lastRelationName, expectedSuffix) {
|
||||
// Already has FK suffix, just reuse the same name
|
||||
recursiveRelationName = lastRelationName
|
||||
logger.Debug("Reusing FK-based relation name for recursion: %s", recursiveRelationName)
|
||||
} else {
|
||||
// Generate FK-based name
|
||||
recursiveRelationName = lastRelationName + expectedSuffix
|
||||
keySource := "RelatedKey"
|
||||
if preload.RecursiveChildKey != "" {
|
||||
keySource = "RecursiveChildKey"
|
||||
}
|
||||
logger.Debug("Generated recursive relation name from %s: %s (from %s)",
|
||||
keySource, recursiveRelationName, recursiveFK)
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Recursive preload for %s has no RecursiveChildKey or RelatedKey, falling back to %s.%s",
|
||||
preload.Relation, preload.Relation, lastRelationName)
|
||||
}
|
||||
|
||||
// Create recursive preload
|
||||
recursivePreload := preload
|
||||
recursivePreload.Relation = preload.Relation + "." + recursiveRelationName
|
||||
recursivePreload.Recursive = false // Prevent infinite recursion at this level
|
||||
|
||||
// Use the recursive FK for child relations, not the parent's RelatedKey
|
||||
if preload.RecursiveChildKey != "" {
|
||||
recursivePreload.RelatedKey = preload.RecursiveChildKey
|
||||
}
|
||||
|
||||
// CRITICAL: Clear parent's WHERE clause - let Bun use FK traversal
|
||||
recursivePreload.Where = ""
|
||||
recursivePreload.Filters = []common.FilterOption{}
|
||||
logger.Debug("Cleared WHERE clause for recursive preload %s at depth %d",
|
||||
recursivePreload.Relation, depth+1)
|
||||
|
||||
// Apply recursively up to depth 8
|
||||
query = h.applyPreloadWithRecursion(query, recursivePreload, allPreloads, model, depth+1)
|
||||
|
||||
// ALSO: Extend any child relations (like DEF) to recursive levels
|
||||
baseRelation := preload.Relation + "."
|
||||
for i := range allPreloads {
|
||||
relatedPreload := allPreloads[i]
|
||||
if strings.HasPrefix(relatedPreload.Relation, baseRelation) &&
|
||||
!strings.Contains(strings.TrimPrefix(relatedPreload.Relation, baseRelation), ".") {
|
||||
childRelationName := strings.TrimPrefix(relatedPreload.Relation, baseRelation)
|
||||
|
||||
extendedChildPreload := relatedPreload
|
||||
extendedChildPreload.Relation = recursivePreload.Relation + "." + childRelationName
|
||||
extendedChildPreload.Recursive = false
|
||||
|
||||
logger.Debug("Extending related preload '%s' to '%s' at recursive depth %d",
|
||||
relatedPreload.Relation, extendedChildPreload.Relation, depth+1)
|
||||
|
||||
query = h.applyPreloadWithRecursion(query, extendedChildPreload, allPreloads, model, depth+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return query
|
||||
@@ -1110,30 +1303,6 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
logger.Info("Updating record in %s.%s", schema, entity)
|
||||
|
||||
// Execute BeforeUpdate hooks
|
||||
hookCtx := &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Tx: h.db,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Data: data,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
logger.Error("BeforeUpdate hook failed: %v", err)
|
||||
h.sendError(w, http.StatusBadRequest, "hook_error", "Hook execution failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
data = hookCtx.Data
|
||||
|
||||
// Convert data to map
|
||||
dataMap, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
@@ -1167,6 +1336,9 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
// Variable to store the updated record
|
||||
var updatedRecord interface{}
|
||||
|
||||
// Declare hook context to be used inside and outside transaction
|
||||
var hookCtx *HookContext
|
||||
|
||||
// Process nested relations if present
|
||||
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||
// Create temporary nested processor with transaction
|
||||
@@ -1174,7 +1346,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
|
||||
// First, read the existing record from the database
|
||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return fmt.Errorf("record not found with ID: %v", targetID)
|
||||
@@ -1204,6 +1376,30 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
||||
nestedRelations = relations
|
||||
}
|
||||
|
||||
// Execute BeforeUpdate hooks inside transaction
|
||||
hookCtx = &HookContext{
|
||||
Context: ctx,
|
||||
Handler: h,
|
||||
Schema: schema,
|
||||
Entity: entity,
|
||||
TableName: tableName,
|
||||
Tx: tx,
|
||||
Model: model,
|
||||
Options: options,
|
||||
ID: id,
|
||||
Data: dataMap,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||
return fmt.Errorf("BeforeUpdate hook failed: %w", err)
|
||||
}
|
||||
|
||||
// Use potentially modified data from hook context
|
||||
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||
dataMap = modifiedData
|
||||
}
|
||||
|
||||
// Merge only non-null and non-empty values from the incoming request into the existing record
|
||||
for key, newValue := range dataMap {
|
||||
// Skip if the value is nil
|
||||
@@ -1344,8 +1540,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1418,8 +1614,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1476,8 +1672,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
||||
}
|
||||
|
||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
continue
|
||||
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||
}
|
||||
|
||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||
@@ -1791,10 +1987,46 @@ func (h *Handler) processChildRelationsForField(
|
||||
parentIDs[baseName] = parentID
|
||||
}
|
||||
|
||||
// Determine which field name to use for setting parent ID in child data
|
||||
// Priority: Use foreign key field name if specified, otherwise use parent's PK name
|
||||
var foreignKeyFieldName string
|
||||
if relInfo.ForeignKey != "" {
|
||||
// Get the JSON name for the foreign key field in the child model
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(relatedModelType, relInfo.ForeignKey)
|
||||
if foreignKeyFieldName == "" {
|
||||
// Fallback to lowercase field name
|
||||
foreignKeyFieldName = strings.ToLower(relInfo.ForeignKey)
|
||||
}
|
||||
} else {
|
||||
// Fallback: use parent's primary key name
|
||||
parentPKName := reflection.GetPrimaryKeyName(parentModelType)
|
||||
foreignKeyFieldName = reflection.GetJSONNameForField(parentModelType, parentPKName)
|
||||
if foreignKeyFieldName == "" {
|
||||
foreignKeyFieldName = strings.ToLower(parentPKName)
|
||||
}
|
||||
}
|
||||
|
||||
// Get the primary key name for the child model to avoid overwriting it in recursive relationships
|
||||
childPKName := reflection.GetPrimaryKeyName(relatedModel)
|
||||
childPKFieldName := reflection.GetJSONNameForField(relatedModelType, childPKName)
|
||||
if childPKFieldName == "" {
|
||||
childPKFieldName = strings.ToLower(childPKName)
|
||||
}
|
||||
|
||||
logger.Debug("Setting parent ID in child data: foreignKeyField=%s, parentID=%v, relForeignKey=%s, childPK=%s",
|
||||
foreignKeyFieldName, parentID, relInfo.ForeignKey, childPKFieldName)
|
||||
|
||||
// Process based on relation type and data structure
|
||||
switch v := relationValue.(type) {
|
||||
case map[string]interface{}:
|
||||
// Single related object
|
||||
// Single related object - add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
v[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in single relation: %s=%v", foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment - same as primary key (recursive relationship): %s", foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, v, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process single relation: %w", err)
|
||||
@@ -1804,6 +2036,14 @@ func (h *Handler) processChildRelationsForField(
|
||||
// Multiple related objects
|
||||
for i, item := range v {
|
||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||
// Add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||
@@ -1814,6 +2054,14 @@ func (h *Handler) processChildRelationsForField(
|
||||
case []map[string]interface{}:
|
||||
// Multiple related objects (typed slice)
|
||||
for i, itemMap := range v {
|
||||
// Add parent ID to foreign key field
|
||||
// IMPORTANT: In recursive relationships, don't overwrite the primary key
|
||||
if parentID != nil && foreignKeyFieldName != "" && foreignKeyFieldName != childPKFieldName {
|
||||
itemMap[foreignKeyFieldName] = parentID
|
||||
logger.Debug("Set foreign key in relation typed array[%d]: %s=%v", i, foreignKeyFieldName, parentID)
|
||||
} else if foreignKeyFieldName == childPKFieldName {
|
||||
logger.Debug("Skipping foreign key assignment in typed array[%d] - same as primary key (recursive relationship): %s", i, foreignKeyFieldName)
|
||||
}
|
||||
_, err := processor.ProcessNestedCUD(ctx, operation, itemMap, relatedModel, parentIDs, relatedTableName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to process relation item %d: %w", i, err)
|
||||
@@ -1827,11 +2075,18 @@ func (h *Handler) processChildRelationsForField(
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTableNameForRelatedModel gets the table name for a related model
|
||||
// getTableNameForRelatedModel gets the table name for a related model.
|
||||
// If the model's TableName() is schema-qualified (e.g. "public.users") the
|
||||
// separator is adjusted for the active driver: underscore for SQLite, dot otherwise.
|
||||
func (h *Handler) getTableNameForRelatedModel(model interface{}, defaultName string) string {
|
||||
if provider, ok := model.(common.TableNameProvider); ok {
|
||||
tableName := provider.TableName()
|
||||
if tableName != "" {
|
||||
if schema, table := h.parseTableName(tableName); schema != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schema, table)
|
||||
}
|
||||
}
|
||||
return tableName
|
||||
}
|
||||
}
|
||||
@@ -1898,7 +2153,11 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
// Column is already cast to TEXT if needed
|
||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
||||
case "in":
|
||||
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
if cond == "" {
|
||||
return query
|
||||
}
|
||||
return applyWhere(cond, inArgs...)
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
@@ -1931,6 +2190,100 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
||||
}
|
||||
}
|
||||
|
||||
// applyOrFilterGroup applies a group of OR filters as a single grouped condition
|
||||
// This ensures OR conditions are properly grouped with parentheses to prevent OR logic from escaping
|
||||
func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common.FilterOption, castInfo []ColumnCastInfo, tableName string) common.SelectQuery {
|
||||
if len(filters) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Build individual filter conditions
|
||||
conditions := []string{}
|
||||
args := []interface{}{}
|
||||
|
||||
for i, filter := range filters {
|
||||
// Qualify the column name with table name if not already qualified
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
|
||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||
if castInfo[i].NeedsCast {
|
||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
||||
}
|
||||
|
||||
// Build the condition based on operator
|
||||
condition, filterArgs := h.buildFilterCondition(qualifiedColumn, filter, tableName)
|
||||
if condition != "" {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, filterArgs...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
// Join all conditions with OR and wrap in parentheses
|
||||
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||
logger.Debug("Applying grouped OR conditions: %s", groupedCondition)
|
||||
|
||||
// Apply as AND condition (the OR is already inside the parentheses)
|
||||
return query.Where(groupedCondition, args...)
|
||||
}
|
||||
|
||||
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||
switch strings.ToLower(filter.Operator) {
|
||||
case "eq", "equals", "=":
|
||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "neq", "not_equals", "ne", "!=", "<>":
|
||||
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gt", "greater_than", ">":
|
||||
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "gte", "greater_than_equals", "ge", ">=":
|
||||
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lt", "less_than", "<":
|
||||
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "lte", "less_than_equals", "le", "<=":
|
||||
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "like":
|
||||
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "ilike":
|
||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
case "in":
|
||||
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||
return cond, inArgs
|
||||
case "between":
|
||||
// Handle between operator - exclusive (> val1 AND < val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s > ? AND %s < ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
}
|
||||
logger.Warn("Invalid BETWEEN filter value format")
|
||||
return "", nil
|
||||
case "between_inclusive":
|
||||
// Handle between inclusive operator - inclusive (>= val1 AND <= val2)
|
||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
} else if values, ok := filter.Value.([]string); ok && len(values) == 2 {
|
||||
return fmt.Sprintf("(%s >= ? AND %s <= ?)", qualifiedColumn, qualifiedColumn), []interface{}{values[0], values[1]}
|
||||
}
|
||||
logger.Warn("Invalid BETWEEN INCLUSIVE filter value format")
|
||||
return "", nil
|
||||
case "is_null", "isnull":
|
||||
// Check for NULL values - don't use cast for NULL checks
|
||||
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||
return fmt.Sprintf("(%s IS NULL OR %s = '')", colName, colName), nil
|
||||
case "is_not_null", "isnotnull":
|
||||
// Check for NOT NULL values - don't use cast for NULL checks
|
||||
colName := h.qualifyColumnName(filter.Column, tableName)
|
||||
return fmt.Sprintf("(%s IS NOT NULL AND %s != '')", colName, colName), nil
|
||||
default:
|
||||
logger.Warn("Unknown filter operator: %s, defaulting to equals", filter.Operator)
|
||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||
}
|
||||
}
|
||||
|
||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||
func (h *Handler) parseTableName(fullTableName string) (schema, table string) {
|
||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||
@@ -1983,10 +2336,16 @@ func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interfac
|
||||
return schema, entity
|
||||
}
|
||||
|
||||
// getTableName returns the full table name including schema (schema.table)
|
||||
// getTableName returns the full table name including schema.
|
||||
// For most drivers the result is "schema.table". For SQLite, which does not
|
||||
// support schema-qualified names, the schema and table are joined with an
|
||||
// underscore: "schema_table".
|
||||
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||
if schemaName != "" {
|
||||
if h.db.DriverName() == "sqlite" {
|
||||
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||
}
|
||||
return tableName
|
||||
@@ -2308,21 +2667,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
||||
}
|
||||
|
||||
// Build WHERE clauses from filters
|
||||
whereClauses := make([]string, 0)
|
||||
for i := range options.Filters {
|
||||
filter := &options.Filters[i]
|
||||
whereClause := h.buildFilterSQL(filter, tableName)
|
||||
if whereClause != "" {
|
||||
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
|
||||
}
|
||||
}
|
||||
|
||||
// Combine WHERE clauses
|
||||
whereSQL := ""
|
||||
if len(whereClauses) > 0 {
|
||||
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
|
||||
}
|
||||
// Build WHERE clause from filters with proper OR grouping
|
||||
whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)
|
||||
|
||||
// Add custom SQL WHERE if provided
|
||||
if options.CustomSQLWhere != "" {
|
||||
@@ -2370,19 +2716,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
||||
var result []struct {
|
||||
RN int64 `bun:"rn"`
|
||||
}
|
||||
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
|
||||
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
||||
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
|
||||
whereInfo := "none"
|
||||
if whereSQL != "" {
|
||||
whereInfo = whereSQL
|
||||
}
|
||||
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
|
||||
}
|
||||
|
||||
return result[0].RN, nil
|
||||
}
|
||||
|
||||
// buildFilterSQL converts a filter to SQL WHERE clause string
|
||||
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
|
||||
// Groups consecutive OR filters together to ensure proper SQL precedence
|
||||
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
|
||||
if len(filters) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var groups []string
|
||||
i := 0
|
||||
|
||||
for i < len(filters) {
|
||||
// Check if this starts an OR group (next filter has OR logic)
|
||||
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||
|
||||
if startORGroup {
|
||||
// Collect all consecutive filters that are OR'd together
|
||||
orGroup := []string{}
|
||||
|
||||
// Add current filter
|
||||
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||
if filterSQL != "" {
|
||||
orGroup = append(orGroup, filterSQL)
|
||||
}
|
||||
|
||||
// Collect remaining OR filters
|
||||
j := i + 1
|
||||
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||
filterSQL := h.buildFilterSQL(&filters[j], tableName)
|
||||
if filterSQL != "" {
|
||||
orGroup = append(orGroup, filterSQL)
|
||||
}
|
||||
j++
|
||||
}
|
||||
|
||||
// Group OR filters with parentheses
|
||||
if len(orGroup) > 0 {
|
||||
if len(orGroup) == 1 {
|
||||
groups = append(groups, orGroup[0])
|
||||
} else {
|
||||
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
|
||||
}
|
||||
}
|
||||
i = j
|
||||
} else {
|
||||
// Single filter with AND logic (or first filter)
|
||||
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||
if filterSQL != "" {
|
||||
groups = append(groups, filterSQL)
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
if len(groups) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(groups, " AND ")
|
||||
}
|
||||
|
||||
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||
|
||||
@@ -2473,6 +2886,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
||||
|
||||
// Filter base RequestOptions
|
||||
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
|
||||
filtered.RequestOptions.JoinAliases = options.JoinAliases
|
||||
|
||||
// Filter SearchColumns
|
||||
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||
|
||||
@@ -26,7 +26,9 @@ type ExtendedRequestOptions struct {
|
||||
CustomSQLOr string
|
||||
|
||||
// Joins
|
||||
Expand []ExpandOption
|
||||
Expand []ExpandOption
|
||||
CustomSQLJoin []string // Custom SQL JOIN clauses
|
||||
JoinAliases []string // Extracted table aliases from CustomSQLJoin for validation
|
||||
|
||||
// Advanced features
|
||||
AdvancedSQL map[string]string // Column -> SQL expression
|
||||
@@ -46,7 +48,8 @@ type ExtendedRequestOptions struct {
|
||||
AtomicTransaction bool
|
||||
|
||||
// X-Files configuration - comprehensive query options as a single JSON object
|
||||
XFiles *XFiles
|
||||
XFiles *XFiles
|
||||
XFilesPresent bool // Flag to indicate if X-Files header was provided
|
||||
}
|
||||
|
||||
// ExpandOption represents a relation expansion configuration
|
||||
@@ -111,6 +114,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
AdvancedSQL: make(map[string]string),
|
||||
ComputedQL: make(map[string]string),
|
||||
Expand: make([]ExpandOption, 0),
|
||||
CustomSQLJoin: make([]string, 0),
|
||||
ResponseFormat: "simple", // Default response format
|
||||
SingleRecordAsObject: true, // Default: normalize single-element arrays to objects
|
||||
}
|
||||
@@ -185,8 +189,7 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
case strings.HasPrefix(key, "x-expand"):
|
||||
h.parseExpand(&options, decodedValue)
|
||||
case strings.HasPrefix(key, "x-custom-sql-join"):
|
||||
// TODO: Implement custom SQL join
|
||||
logger.Debug("Custom SQL join not yet implemented: %s", decodedValue)
|
||||
h.parseCustomSQLJoin(&options, decodedValue)
|
||||
|
||||
// Sorting & Pagination
|
||||
case strings.HasPrefix(key, "x-sort"):
|
||||
@@ -271,7 +274,10 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve relation names (convert table names to field names) if model is provided
|
||||
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
|
||||
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
|
||||
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
|
||||
// disambiguate when multiple fields point to the same related type.
|
||||
if model != nil {
|
||||
h.resolveRelationNamesInOptions(&options, model)
|
||||
}
|
||||
@@ -354,6 +360,12 @@ func (h *Handler) parseSearchOp(options *ExtendedRequestOptions, headerKey, valu
|
||||
operator := parts[0]
|
||||
colName := parts[1]
|
||||
|
||||
if strings.HasPrefix(colName, "cql") {
|
||||
// Computed column - Will not filter on it
|
||||
logger.Warn("Search operators on computed columns are not supported: %s", colName)
|
||||
return
|
||||
}
|
||||
|
||||
// Map operator names to filter operators
|
||||
filterOp := h.mapSearchOperator(colName, operator, value)
|
||||
|
||||
@@ -489,6 +501,112 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) {
|
||||
}
|
||||
}
|
||||
|
||||
// parseCustomSQLJoin parses x-custom-sql-join header
|
||||
// Format: Single JOIN clause or multiple JOIN clauses separated by |
|
||||
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
|
||||
// Example: "LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id"
|
||||
func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value string) {
|
||||
if value == "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Split by | for multiple joins
|
||||
joins := strings.Split(value, "|")
|
||||
for _, joinStr := range joins {
|
||||
joinStr = strings.TrimSpace(joinStr)
|
||||
if joinStr == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Basic validation: should contain "JOIN" keyword
|
||||
upperJoin := strings.ToUpper(joinStr)
|
||||
if !strings.Contains(upperJoin, "JOIN") {
|
||||
logger.Warn("Invalid custom SQL join (missing JOIN keyword): %s", joinStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Sanitize the join clause using common.SanitizeWhereClause
|
||||
// Note: This is basic sanitization - in production you may want stricter validation
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinStr, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("Custom SQL join failed sanitization: %s", joinStr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract table alias from the JOIN clause
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
options.JoinAliases = append(options.JoinAliases, alias)
|
||||
// Also add to the embedded RequestOptions for validation
|
||||
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
|
||||
logger.Debug("Extracted join alias: %s", alias)
|
||||
}
|
||||
|
||||
logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
|
||||
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
|
||||
}
|
||||
}
|
||||
|
||||
// extractJoinAlias extracts the table alias from a JOIN clause
|
||||
// Examples:
|
||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||
// - "JOIN roles r ON ..." -> "r"
|
||||
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||
func extractJoinAlias(joinClause string) string {
|
||||
upperJoin := strings.ToUpper(joinClause)
|
||||
|
||||
// Find the "JOIN" keyword position
|
||||
joinIdx := strings.Index(upperJoin, "JOIN")
|
||||
if joinIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Lateral joins: alias is the word after the closing ) and before ON
|
||||
if strings.Contains(upperJoin, "LATERAL") {
|
||||
lastClose := strings.LastIndex(joinClause, ")")
|
||||
if lastClose != -1 {
|
||||
words := strings.Fields(joinClause[lastClose+1:])
|
||||
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||
return words[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||
onIdx := strings.Index(upperJoin, " ON ")
|
||||
if onIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Extract the part between JOIN and ON
|
||||
betweenJoinAndOn := strings.TrimSpace(joinClause[joinIdx+4 : onIdx])
|
||||
|
||||
// Split by spaces to get words
|
||||
words := strings.Fields(betweenJoinAndOn)
|
||||
if len(words) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If there's an AS keyword, the alias is after it
|
||||
for i, word := range words {
|
||||
if strings.EqualFold(word, "AS") && i+1 < len(words) {
|
||||
return words[i+1]
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise, the alias is the last word (if there are 2+ words)
|
||||
// Format: "table_name alias" or just "table_name"
|
||||
if len(words) >= 2 {
|
||||
return words[len(words)-1]
|
||||
}
|
||||
|
||||
// Only one word means it's just the table name, no alias
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSorting parses x-sort header
|
||||
// Format: +field1,-field2,field3 (+ for ASC, - for DESC, default ASC)
|
||||
func (h *Handler) parseSorting(options *ExtendedRequestOptions, value string) {
|
||||
@@ -590,6 +708,7 @@ func (h *Handler) parseXFiles(options *ExtendedRequestOptions, value string) {
|
||||
|
||||
// Store the original XFiles for reference
|
||||
options.XFiles = &xfiles
|
||||
options.XFilesPresent = true // Mark that X-Files header was provided
|
||||
|
||||
// Map XFiles fields to ExtendedRequestOptions
|
||||
|
||||
@@ -757,8 +876,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
||||
|
||||
// Resolve each part of the path
|
||||
currentModel := model
|
||||
for _, part := range parts {
|
||||
resolvedPart := h.resolveRelationName(currentModel, part)
|
||||
for partIdx, part := range parts {
|
||||
isLast := partIdx == len(parts)-1
|
||||
var resolvedPart string
|
||||
if isLast {
|
||||
// For the final part, use join-key-aware resolution to disambiguate when
|
||||
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
|
||||
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
|
||||
localKey := preload.RelatedKey
|
||||
if localKey == "" {
|
||||
localKey = preload.ForeignKey
|
||||
}
|
||||
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
|
||||
} else {
|
||||
resolvedPart = h.resolveRelationName(currentModel, part)
|
||||
}
|
||||
resolvedParts = append(resolvedParts, resolvedPart)
|
||||
|
||||
// Try to get the model type for the next level
|
||||
@@ -874,6 +1006,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
|
||||
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
|
||||
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
|
||||
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
|
||||
if localKey == "" {
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(model)
|
||||
if modelType == nil {
|
||||
return nameOrTable
|
||||
}
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
modelType = modelType.Elem()
|
||||
}
|
||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||
return nameOrTable
|
||||
}
|
||||
|
||||
// If it's already a direct field name, return as-is (no ambiguity).
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
if modelType.Field(i).Name == nameOrTable {
|
||||
return nameOrTable
|
||||
}
|
||||
}
|
||||
|
||||
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||
localKeyLower := strings.ToLower(localKey)
|
||||
|
||||
// Find all fields whose related type matches nameOrTable, then pick the one
|
||||
// whose bun join tag local key matches localKey.
|
||||
var fallbackField string
|
||||
for i := 0; i < modelType.NumField(); i++ {
|
||||
field := modelType.Field(i)
|
||||
fieldType := field.Type
|
||||
|
||||
var targetType reflect.Type
|
||||
if fieldType.Kind() == reflect.Slice {
|
||||
targetType = fieldType.Elem()
|
||||
} else if fieldType.Kind() == reflect.Ptr {
|
||||
targetType = fieldType.Elem()
|
||||
}
|
||||
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
||||
targetType = targetType.Elem()
|
||||
}
|
||||
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||
continue
|
||||
}
|
||||
|
||||
normalizedTypeName := strings.ToLower(targetType.Name())
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||
if normalizedTypeName != normalizedInput {
|
||||
continue
|
||||
}
|
||||
|
||||
// Type name matches; record as fallback.
|
||||
if fallbackField == "" {
|
||||
fallbackField = field.Name
|
||||
}
|
||||
|
||||
// Check bun join tag: "join:localKey=foreignKey"
|
||||
bunTag := field.Tag.Get("bun")
|
||||
for _, tagPart := range strings.Split(bunTag, ",") {
|
||||
tagPart = strings.TrimSpace(tagPart)
|
||||
if !strings.HasPrefix(tagPart, "join:") {
|
||||
continue
|
||||
}
|
||||
joinSpec := strings.TrimPrefix(tagPart, "join:")
|
||||
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
|
||||
joinCols := strings.Fields(joinSpec)
|
||||
if len(joinCols) == 0 {
|
||||
joinCols = []string{joinSpec}
|
||||
}
|
||||
for _, joinCol := range joinCols {
|
||||
eqIdx := strings.Index(joinCol, "=")
|
||||
if eqIdx < 0 {
|
||||
continue
|
||||
}
|
||||
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
|
||||
if joinLocalKey == localKeyLower {
|
||||
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
|
||||
return field.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if fallbackField != "" {
|
||||
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
|
||||
return fallbackField
|
||||
}
|
||||
return h.resolveRelationName(model, nameOrTable)
|
||||
}
|
||||
|
||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||
// and recursively processes its children
|
||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||
@@ -881,11 +1108,33 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
return
|
||||
}
|
||||
|
||||
// Store the table name as-is for now - it will be resolved to field name later
|
||||
// when we have the model instance available
|
||||
relationPath := xfile.TableName
|
||||
// Use the Prefix (e.g., "MAL") as the relation name, which matches the Go struct field name
|
||||
// Fall back to TableName if Prefix is not specified
|
||||
relationName := xfile.Prefix
|
||||
if relationName == "" {
|
||||
relationName = xfile.TableName
|
||||
}
|
||||
|
||||
// SPECIAL CASE: For recursive child tables, generate FK-based relation name
|
||||
// Example: If prefix is "MAL" and relatedkey is "rid_parentmastertaskitem",
|
||||
// the actual struct field is "MAL_RID_PARENTMASTERTASKITEM", not "MAL"
|
||||
if xfile.Recursive && xfile.RelatedKey != "" && basePath != "" {
|
||||
// Check if this is a self-referencing recursive relation (same table as parent)
|
||||
// by comparing the last part of basePath with the current prefix
|
||||
basePathParts := strings.Split(basePath, ".")
|
||||
lastPrefix := basePathParts[len(basePathParts)-1]
|
||||
|
||||
if lastPrefix == relationName {
|
||||
// This is a recursive self-reference, use FK-based name
|
||||
fkUpper := strings.ToUpper(xfile.RelatedKey)
|
||||
relationName = relationName + "_" + fkUpper
|
||||
logger.Debug("X-Files: Generated FK-based relation name for recursive table: %s", relationName)
|
||||
}
|
||||
}
|
||||
|
||||
relationPath := relationName
|
||||
if basePath != "" {
|
||||
relationPath = basePath + "." + xfile.TableName
|
||||
relationPath = basePath + "." + relationName
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Adding preload for relation: %s", relationPath)
|
||||
@@ -893,6 +1142,7 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
// Create PreloadOption from XFiles configuration
|
||||
preloadOpt := common.PreloadOption{
|
||||
Relation: relationPath,
|
||||
TableName: xfile.TableName, // Store the actual database table name for WHERE clause processing
|
||||
Columns: xfile.Columns,
|
||||
OmitColumns: xfile.OmitColumns,
|
||||
}
|
||||
@@ -932,15 +1182,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
}
|
||||
}
|
||||
|
||||
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
|
||||
if len(xfile.SqlJoins) > 0 {
|
||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||
|
||||
for _, joinClause := range xfile.SqlJoins {
|
||||
// Sanitize the join clause
|
||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||
if sanitizedJoin == "" {
|
||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||
continue
|
||||
}
|
||||
|
||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||
|
||||
// Extract join alias for validation
|
||||
alias := extractJoinAlias(sanitizedJoin)
|
||||
if alias != "" {
|
||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||
}
|
||||
|
||||
// Add WHERE clause if SQL conditions specified
|
||||
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
|
||||
whereConditions := make([]string, 0)
|
||||
if len(xfile.SqlAnd) > 0 {
|
||||
// Process each SQL condition: add table prefixes and sanitize
|
||||
var sqlAndOpts *common.RequestOptions
|
||||
if len(preloadOpt.JoinAliases) > 0 {
|
||||
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
|
||||
}
|
||||
for _, sqlCond := range xfile.SqlAnd {
|
||||
// First add table prefixes to unqualified columns
|
||||
prefixedCond := common.AddTablePrefixToColumns(sqlCond, xfile.TableName)
|
||||
// Then sanitize the condition
|
||||
sanitizedCond := common.SanitizeWhereClause(prefixedCond, xfile.TableName)
|
||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
|
||||
if sanitizedCond != "" {
|
||||
whereConditions = append(whereConditions, sanitizedCond)
|
||||
}
|
||||
@@ -985,13 +1262,46 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||
}
|
||||
|
||||
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||
// and store the recursive child's RelatedKey for recursion generation
|
||||
hasRecursiveChild := false
|
||||
if len(xfile.ChildTables) > 0 {
|
||||
for _, childTable := range xfile.ChildTables {
|
||||
if childTable.Recursive && childTable.TableName == xfile.TableName {
|
||||
hasRecursiveChild = true
|
||||
preloadOpt.Recursive = true
|
||||
preloadOpt.RecursiveChildKey = childTable.RelatedKey
|
||||
logger.Debug("X-Files: Detected recursive child for %s, marking parent as recursive (recursive FK: %s)",
|
||||
relationPath, childTable.RelatedKey)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Skip adding this preload if it's a recursive child (it will be handled by parent's Recursive flag)
|
||||
if xfile.Recursive && basePath != "" {
|
||||
logger.Debug("X-Files: Skipping recursive child preload: %s (will be handled by parent)", relationPath)
|
||||
// Still process its parent/child tables for relations like DEF
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
return
|
||||
}
|
||||
|
||||
// Add the preload option
|
||||
options.Preload = append(options.Preload, preloadOpt)
|
||||
logger.Debug("X-Files: Added preload [%d]: Relation=%s, Recursive=%v, RelatedKey=%s, RecursiveChildKey=%s, Where=%s",
|
||||
len(options.Preload)-1, preloadOpt.Relation, preloadOpt.Recursive, preloadOpt.RelatedKey, preloadOpt.RecursiveChildKey, preloadOpt.Where)
|
||||
|
||||
// Recursively process nested ParentTables and ChildTables
|
||||
if xfile.Recursive {
|
||||
logger.Debug("X-Files: Recursive preload enabled for: %s", relationPath)
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
// Skip processing child tables if we already detected and handled a recursive child
|
||||
if hasRecursiveChild {
|
||||
logger.Debug("X-Files: Skipping child table processing for %s (recursive child already handled)", relationPath)
|
||||
// But still process parent tables
|
||||
if len(xfile.ParentTables) > 0 {
|
||||
logger.Debug("X-Files: Processing %d parent tables for %s", len(xfile.ParentTables), relationPath)
|
||||
for _, parentTable := range xfile.ParentTables {
|
||||
h.addXFilesPreload(parentTable, options, relationPath)
|
||||
}
|
||||
}
|
||||
} else if len(xfile.ParentTables) > 0 || len(xfile.ChildTables) > 0 {
|
||||
h.processXFilesRelations(xfile, options, relationPath)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
func TestDecodeHeaderValue(t *testing.T) {
|
||||
@@ -37,6 +39,131 @@ func TestDecodeHeaderValue(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddXFilesPreload_WithSqlJoins(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
options := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Preload: make([]common.PreloadOption, 0),
|
||||
},
|
||||
}
|
||||
|
||||
// Create an XFiles with SqlJoins
|
||||
xfile := &XFiles{
|
||||
TableName: "users",
|
||||
SqlJoins: []string{
|
||||
"LEFT JOIN departments d ON d.id = users.department_id",
|
||||
"INNER JOIN roles r ON r.id = users.role_id",
|
||||
},
|
||||
FilterFields: []struct {
|
||||
Field string `json:"field"`
|
||||
Value string `json:"value"`
|
||||
Operator string `json:"operator"`
|
||||
}{
|
||||
{Field: "d.active", Value: "true", Operator: "eq"},
|
||||
{Field: "r.name", Value: "admin", Operator: "eq"},
|
||||
},
|
||||
}
|
||||
|
||||
// Add the XFiles preload
|
||||
handler.addXFilesPreload(xfile, options, "")
|
||||
|
||||
// Verify that a preload was added
|
||||
if len(options.Preload) != 1 {
|
||||
t.Fatalf("Expected 1 preload, got %d", len(options.Preload))
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify relation name
|
||||
if preload.Relation != "users" {
|
||||
t.Errorf("Expected relation 'users', got '%s'", preload.Relation)
|
||||
}
|
||||
|
||||
// Verify SqlJoins were transferred
|
||||
if len(preload.SqlJoins) != 2 {
|
||||
t.Fatalf("Expected 2 SQL joins, got %d", len(preload.SqlJoins))
|
||||
}
|
||||
|
||||
// Verify JoinAliases were extracted
|
||||
if len(preload.JoinAliases) != 2 {
|
||||
t.Fatalf("Expected 2 join aliases, got %d", len(preload.JoinAliases))
|
||||
}
|
||||
|
||||
// Verify the aliases are correct
|
||||
expectedAliases := []string{"d", "r"}
|
||||
for i, expected := range expectedAliases {
|
||||
if preload.JoinAliases[i] != expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", expected, preload.JoinAliases[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify filters were added
|
||||
if len(preload.Filters) != 2 {
|
||||
t.Fatalf("Expected 2 filters, got %d", len(preload.Filters))
|
||||
}
|
||||
|
||||
// Verify filter columns reference joined tables
|
||||
if preload.Filters[0].Column != "d.active" {
|
||||
t.Errorf("Expected filter column 'd.active', got '%s'", preload.Filters[0].Column)
|
||||
}
|
||||
if preload.Filters[1].Column != "r.name" {
|
||||
t.Errorf("Expected filter column 'r.name', got '%s'", preload.Filters[1].Column)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractJoinAlias(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
joinClause string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN with alias",
|
||||
joinClause: "LEFT JOIN departments d ON d.id = users.department_id",
|
||||
expected: "d",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN with AS keyword",
|
||||
joinClause: "INNER JOIN users AS u ON u.id = orders.user_id",
|
||||
expected: "u",
|
||||
},
|
||||
{
|
||||
name: "JOIN without alias",
|
||||
joinClause: "JOIN roles ON roles.id = users.role_id",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Complex join with multiple conditions",
|
||||
joinClause: "LEFT OUTER JOIN products p ON p.id = items.product_id AND p.active = true",
|
||||
expected: "p",
|
||||
},
|
||||
{
|
||||
name: "Invalid join (no ON clause)",
|
||||
joinClause: "LEFT JOIN departments",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with alias",
|
||||
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
{
|
||||
name: "LATERAL join with multiline subquery containing inner ON",
|
||||
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||
expected: "fn",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractJoinAlias(tt.joinClause)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected alias '%s', got '%s'", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: The following functions are unexported (lowercase) and cannot be tested directly:
|
||||
// - parseSelectFields
|
||||
// - parseFieldFilter
|
||||
|
||||
@@ -12,6 +12,10 @@ import (
|
||||
type HookType string
|
||||
|
||||
const (
|
||||
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||
// Use this for auth checks that need model rules and user context simultaneously.
|
||||
BeforeHandle HookType = "before_handle"
|
||||
|
||||
// Read operation hooks
|
||||
BeforeRead HookType = "before_read"
|
||||
AfterRead HookType = "after_read"
|
||||
@@ -42,6 +46,9 @@ type HookContext struct {
|
||||
Model interface{}
|
||||
Options ExtendedRequestOptions
|
||||
|
||||
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||
Operation string
|
||||
|
||||
// Operation-specific fields
|
||||
ID string
|
||||
Data interface{} // For create/update operations
|
||||
@@ -56,6 +63,14 @@ type HookContext struct {
|
||||
// Response writer - allows hooks to modify response
|
||||
Writer common.ResponseWriter
|
||||
|
||||
// Request - the original HTTP request
|
||||
Request common.Request
|
||||
|
||||
// Allow hooks to abort the operation
|
||||
Abort bool // If set to true, the operation will be aborted
|
||||
AbortMessage string // Message to return if aborted
|
||||
AbortCode int // HTTP status code if aborted
|
||||
|
||||
// Tx provides access to the database/transaction for executing additional SQL
|
||||
// This allows hooks to run custom queries in addition to the main Query chain
|
||||
Tx common.Database
|
||||
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||
return fmt.Errorf("hook execution failed: %w", err)
|
||||
}
|
||||
|
||||
// Check if hook requested abort
|
||||
if ctx.Abort {
|
||||
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||
}
|
||||
}
|
||||
|
||||
// logger.Debug("All hooks for %s executed successfully", hookType)
|
||||
|
||||
110
pkg/restheadspec/preload_tablename_test.go
Normal file
110
pkg/restheadspec/preload_tablename_test.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestPreloadOption_TableName verifies that TableName field is properly used
|
||||
// when provided in PreloadOption for WHERE clause processing
|
||||
func TestPreloadOption_TableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
preload common.PreloadOption
|
||||
expectedTable string
|
||||
}{
|
||||
{
|
||||
name: "TableName provided explicitly",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "mastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
expectedTable: "mastertaskitem",
|
||||
},
|
||||
{
|
||||
name: "TableName empty, should use empty string",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "MTL.MAL.MAL_RID_PARENTMASTERTASKITEM",
|
||||
TableName: "",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
},
|
||||
expectedTable: "",
|
||||
},
|
||||
{
|
||||
name: "Simple relation without nested path",
|
||||
preload: common.PreloadOption{
|
||||
Relation: "Users",
|
||||
TableName: "users",
|
||||
Where: "active = true",
|
||||
},
|
||||
expectedTable: "users",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test that the TableName field stores the correct value
|
||||
if tt.preload.TableName != tt.expectedTable {
|
||||
t.Errorf("PreloadOption.TableName = %q, want %q", tt.preload.TableName, tt.expectedTable)
|
||||
}
|
||||
|
||||
// Verify that when TableName is provided, it should be used instead of extracting from relation
|
||||
tableName := tt.preload.TableName
|
||||
if tableName == "" {
|
||||
// This simulates the fallback logic in handler.go
|
||||
// In reality, reflection.ExtractTableNameOnly would be called
|
||||
tableName = tt.expectedTable
|
||||
}
|
||||
|
||||
if tableName != tt.expectedTable {
|
||||
t.Errorf("Resolved table name = %q, want %q", tableName, tt.expectedTable)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestXFilesPreload_StoresTableName verifies that XFiles processing
|
||||
// stores the table name in PreloadOption and doesn't add table prefixes to WHERE clauses
|
||||
func TestXFilesPreload_StoresTableName(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
xfiles := &XFiles{
|
||||
TableName: "mastertaskitem",
|
||||
Prefix: "MAL",
|
||||
PrimaryKey: "rid_mastertaskitem",
|
||||
RelatedKey: "rid_mastertask", // Changed from rid_parentmastertaskitem
|
||||
Recursive: false, // Changed from true (recursive children are now skipped)
|
||||
SqlAnd: []string{"rid_parentmastertaskitem is null"},
|
||||
}
|
||||
|
||||
options := &ExtendedRequestOptions{}
|
||||
|
||||
// Process XFiles
|
||||
handler.addXFilesPreload(xfiles, options, "MTL")
|
||||
|
||||
// Verify that a preload was added
|
||||
if len(options.Preload) == 0 {
|
||||
t.Fatal("Expected at least one preload to be added")
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify the table name is stored
|
||||
if preload.TableName != "mastertaskitem" {
|
||||
t.Errorf("PreloadOption.TableName = %q, want %q", preload.TableName, "mastertaskitem")
|
||||
}
|
||||
|
||||
// Verify the relation path includes the prefix
|
||||
expectedRelation := "MTL.MAL"
|
||||
if preload.Relation != expectedRelation {
|
||||
t.Errorf("PreloadOption.Relation = %q, want %q", preload.Relation, expectedRelation)
|
||||
}
|
||||
|
||||
// Verify WHERE clause does NOT have table prefix (prefixes only needed for JOINs)
|
||||
expectedWhere := "rid_parentmastertaskitem is null"
|
||||
if preload.Where != expectedWhere {
|
||||
t.Errorf("PreloadOption.Where = %q, want %q (no table prefix)", preload.Where, expectedWhere)
|
||||
}
|
||||
}
|
||||
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
91
pkg/restheadspec/preload_where_joins_test.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestPreloadWhereClause_WithJoins verifies that table prefixes are added
|
||||
// to WHERE clauses when SqlJoins are present
|
||||
func TestPreloadWhereClause_WithJoins(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
where string
|
||||
sqlJoins []string
|
||||
expectedPrefix bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "No joins - no prefix needed",
|
||||
where: "status = 'active'",
|
||||
sqlJoins: []string{},
|
||||
expectedPrefix: false,
|
||||
description: "Without JOINs, Bun knows the table context",
|
||||
},
|
||||
{
|
||||
name: "Has joins - prefix needed",
|
||||
where: "status = 'active'",
|
||||
sqlJoins: []string{"LEFT JOIN other_table ot ON ot.id = main.other_id"},
|
||||
expectedPrefix: true,
|
||||
description: "With JOINs, table prefix disambiguates columns",
|
||||
},
|
||||
{
|
||||
name: "Already has prefix - no change",
|
||||
where: "users.status = 'active'",
|
||||
sqlJoins: []string{"LEFT JOIN roles r ON r.id = users.role_id"},
|
||||
expectedPrefix: true,
|
||||
description: "Existing prefix should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// This test documents the expected behavior
|
||||
// The actual logic is in handler.go lines 916-937
|
||||
|
||||
hasJoins := len(tt.sqlJoins) > 0
|
||||
if hasJoins != tt.expectedPrefix {
|
||||
t.Errorf("Test expectation mismatch: hasJoins=%v, expectedPrefix=%v",
|
||||
hasJoins, tt.expectedPrefix)
|
||||
}
|
||||
|
||||
t.Logf("%s: %s", tt.name, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestXFilesWithJoins_AddsTablePrefix verifies that XFiles with SqlJoins
|
||||
// results in table prefixes being added to WHERE clauses
|
||||
func TestXFilesWithJoins_AddsTablePrefix(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
xfiles := &XFiles{
|
||||
TableName: "users",
|
||||
Prefix: "USR",
|
||||
PrimaryKey: "id",
|
||||
SqlAnd: []string{"status = 'active'"},
|
||||
SqlJoins: []string{"LEFT JOIN departments d ON d.id = users.department_id"},
|
||||
}
|
||||
|
||||
options := &ExtendedRequestOptions{}
|
||||
handler.addXFilesPreload(xfiles, options, "")
|
||||
|
||||
if len(options.Preload) == 0 {
|
||||
t.Fatal("Expected at least one preload to be added")
|
||||
}
|
||||
|
||||
preload := options.Preload[0]
|
||||
|
||||
// Verify SqlJoins were stored
|
||||
if len(preload.SqlJoins) != 1 {
|
||||
t.Errorf("Expected 1 SqlJoin, got %d", len(preload.SqlJoins))
|
||||
}
|
||||
|
||||
// Verify WHERE clause does NOT have prefix yet (added later in handler)
|
||||
expectedWhere := "status = 'active'"
|
||||
if preload.Where != expectedWhere {
|
||||
t.Errorf("PreloadOption.Where = %q, want %q", preload.Where, expectedWhere)
|
||||
}
|
||||
|
||||
// Note: The handler will add the prefix when it sees SqlJoins
|
||||
// This is tested in the handler itself, not during XFiles parsing
|
||||
}
|
||||
@@ -301,6 +301,163 @@ func TestParseOptionsFromQueryParams(t *testing.T) {
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL JOIN from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) == 0 {
|
||||
t.Error("Expected CustomSQLJoin to be set")
|
||||
return
|
||||
}
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
expected := `LEFT JOIN departments d ON d.id = employees.department_id`
|
||||
if options.CustomSQLJoin[0] != expected {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse multiple custom SQL JOINs from query params",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles r ON r.id = e.role_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) != 2 {
|
||||
t.Errorf("Expected 2 custom SQL joins, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
expected1 := `LEFT JOIN departments d ON d.id = e.dept_id`
|
||||
expected2 := `INNER JOIN roles r ON r.id = e.role_id`
|
||||
if options.CustomSQLJoin[0] != expected1 {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected1, options.CustomSQLJoin[0])
|
||||
}
|
||||
if options.CustomSQLJoin[1] != expected2 {
|
||||
t.Errorf("Expected CustomSQLJoin[1]=%q, got %q", expected2, options.CustomSQLJoin[1])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Parse custom SQL JOIN from headers",
|
||||
headers: map[string]string{
|
||||
"X-Custom-SQL-Join": `LEFT JOIN users u ON u.id = posts.user_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.CustomSQLJoin) == 0 {
|
||||
t.Error("Expected CustomSQLJoin to be set from header")
|
||||
return
|
||||
}
|
||||
expected := `LEFT JOIN users u ON u.id = posts.user_id`
|
||||
if options.CustomSQLJoin[0] != expected {
|
||||
t.Errorf("Expected CustomSQLJoin[0]=%q, got %q", expected, options.CustomSQLJoin[0])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Extract aliases from custom SQL JOIN",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.JoinAliases) == 0 {
|
||||
t.Error("Expected JoinAliases to be extracted")
|
||||
return
|
||||
}
|
||||
if len(options.JoinAliases) != 1 {
|
||||
t.Errorf("Expected 1 join alias, got %d", len(options.JoinAliases))
|
||||
return
|
||||
}
|
||||
if options.JoinAliases[0] != "d" {
|
||||
t.Errorf("Expected join alias 'd', got %q", options.JoinAliases[0])
|
||||
}
|
||||
// Also check that it's in the embedded RequestOptions
|
||||
if len(options.RequestOptions.JoinAliases) != 1 || options.RequestOptions.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias to also be in RequestOptions.JoinAliases")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Extract multiple aliases from multiple custom SQL JOINs",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = e.dept_id | INNER JOIN roles AS r ON r.id = e.role_id`,
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
if len(options.JoinAliases) != 2 {
|
||||
t.Errorf("Expected 2 join aliases, got %d", len(options.JoinAliases))
|
||||
return
|
||||
}
|
||||
expectedAliases := []string{"d", "r"}
|
||||
for i, expected := range expectedAliases {
|
||||
if options.JoinAliases[i] != expected {
|
||||
t.Errorf("Expected join alias[%d]=%q, got %q", i, expected, options.JoinAliases[i])
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom JOIN with sort on joined table",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
"x-sort": "d.name,employees.id",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
// Verify join was added
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Errorf("Expected 1 custom SQL join, got %d", len(options.CustomSQLJoin))
|
||||
return
|
||||
}
|
||||
// Verify alias was extracted
|
||||
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias 'd' to be extracted")
|
||||
return
|
||||
}
|
||||
// Verify sort was parsed
|
||||
if len(options.Sort) != 2 {
|
||||
t.Errorf("Expected 2 sort options, got %d", len(options.Sort))
|
||||
return
|
||||
}
|
||||
if options.Sort[0].Column != "d.name" {
|
||||
t.Errorf("Expected first sort column 'd.name', got %q", options.Sort[0].Column)
|
||||
}
|
||||
if options.Sort[1].Column != "employees.id" {
|
||||
t.Errorf("Expected second sort column 'employees.id', got %q", options.Sort[1].Column)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Custom JOIN with filter on joined table",
|
||||
queryParams: map[string]string{
|
||||
"x-custom-sql-join": `LEFT JOIN departments d ON d.id = employees.department_id`,
|
||||
"x-searchop-eq-d.name": "Engineering",
|
||||
},
|
||||
validate: func(t *testing.T, options ExtendedRequestOptions) {
|
||||
// Verify join was added
|
||||
if len(options.CustomSQLJoin) != 1 {
|
||||
t.Error("Expected 1 custom SQL join")
|
||||
return
|
||||
}
|
||||
// Verify alias was extracted
|
||||
if len(options.JoinAliases) != 1 || options.JoinAliases[0] != "d" {
|
||||
t.Error("Expected join alias 'd' to be extracted")
|
||||
return
|
||||
}
|
||||
// Verify filter was parsed
|
||||
if len(options.Filters) != 1 {
|
||||
t.Errorf("Expected 1 filter, got %d", len(options.Filters))
|
||||
return
|
||||
}
|
||||
if options.Filters[0].Column != "d.name" {
|
||||
t.Errorf("Expected filter column 'd.name', got %q", options.Filters[0].Column)
|
||||
}
|
||||
if options.Filters[0].Operator != "eq" {
|
||||
t.Errorf("Expected filter operator 'eq', got %q", options.Filters[0].Operator)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -395,6 +552,55 @@ func TestHeadersAndQueryParamsCombined(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomJoinAliasExtraction tests the extractJoinAlias helper function
|
||||
func TestCustomJoinAliasExtraction(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
join string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "LEFT JOIN with alias",
|
||||
join: "LEFT JOIN departments d ON d.id = employees.department_id",
|
||||
expected: "d",
|
||||
},
|
||||
{
|
||||
name: "INNER JOIN with AS keyword",
|
||||
join: "INNER JOIN users AS u ON u.id = posts.user_id",
|
||||
expected: "u",
|
||||
},
|
||||
{
|
||||
name: "Simple JOIN with alias",
|
||||
join: "JOIN roles r ON r.id = user_roles.role_id",
|
||||
expected: "r",
|
||||
},
|
||||
{
|
||||
name: "JOIN without alias (just table name)",
|
||||
join: "JOIN departments ON departments.id = employees.dept_id",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "RIGHT JOIN with alias",
|
||||
join: "RIGHT JOIN orders o ON o.customer_id = customers.id",
|
||||
expected: "o",
|
||||
},
|
||||
{
|
||||
name: "FULL OUTER JOIN with AS",
|
||||
join: "FULL OUTER JOIN products AS p ON p.id = order_items.product_id",
|
||||
expected: "p",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractJoinAlias(tt.join)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractJoinAlias(%q) = %q, want %q", tt.join, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to check if a string contains a substring
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && containsHelper(s, substr))
|
||||
|
||||
391
pkg/restheadspec/recursive_preload_test.go
Normal file
391
pkg/restheadspec/recursive_preload_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
//go:build !integration
|
||||
// +build !integration
|
||||
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
)
|
||||
|
||||
// TestRecursivePreloadClearsWhereClause tests that recursive preloads
|
||||
// correctly clear the WHERE clause from the parent level to allow
|
||||
// Bun to use foreign key relationships for loading children
|
||||
func TestRecursivePreloadClearsWhereClause(t *testing.T) {
|
||||
// Create a mock handler
|
||||
handler := &Handler{}
|
||||
|
||||
// Create a preload option with a WHERE clause that filters root items
|
||||
// This simulates the xfiles use case where the first level has a filter
|
||||
// like "rid_parentmastertaskitem is null" to get root items
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MastertaskItems",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
Filters: []common.FilterOption{
|
||||
{
|
||||
Column: "rid_parentmastertaskitem",
|
||||
Operator: "is null",
|
||||
Value: nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create a mock query that tracks operations
|
||||
mockQuery := &mockSelectQuery{
|
||||
operations: []string{},
|
||||
}
|
||||
|
||||
// Apply the recursive preload at depth 0
|
||||
// This should:
|
||||
// 1. Apply the initial preload with the WHERE clause
|
||||
// 2. Create a recursive preload without the WHERE clause
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
// Verify the mock query received the operations
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Check that we have at least 2 PreloadRelation calls:
|
||||
// 1. The initial "MastertaskItems" with WHERE clause
|
||||
// 2. The recursive "MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" without WHERE clause
|
||||
preloadCount := 0
|
||||
recursivePreloadFound := false
|
||||
whereAppliedToRecursive := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MastertaskItems" {
|
||||
preloadCount++
|
||||
}
|
||||
if op == "PreloadRelation:MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM" {
|
||||
recursivePreloadFound = true
|
||||
}
|
||||
// Check if WHERE was applied to the recursive preload (it shouldn't be)
|
||||
if op == "Where:rid_parentmastertaskitem is null" && recursivePreloadFound {
|
||||
whereAppliedToRecursive = true
|
||||
}
|
||||
}
|
||||
|
||||
if preloadCount < 1 {
|
||||
t.Errorf("Expected at least 1 PreloadRelation call, got %d", preloadCount)
|
||||
}
|
||||
|
||||
if !recursivePreloadFound {
|
||||
t.Errorf("Expected recursive preload 'MastertaskItems.MastertaskItems_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if whereAppliedToRecursive {
|
||||
t.Error("WHERE clause should not be applied to recursive preload levels")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecursivePreloadWithChildRelations tests that child relations
|
||||
// (like DEF in MAL.DEF) are properly extended to recursive levels
|
||||
func TestRecursivePreloadWithChildRelations(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Create the main recursive preload
|
||||
recursivePreload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
Where: "rid_parentmastertaskitem is null",
|
||||
}
|
||||
|
||||
// Create a child relation that should be extended
|
||||
childPreload := common.PreloadOption{
|
||||
Relation: "MAL.DEF",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{
|
||||
operations: []string{},
|
||||
}
|
||||
|
||||
allPreloads := []common.PreloadOption{recursivePreload, childPreload}
|
||||
|
||||
// Apply both preloads - the child preload should be extended when the recursive one processes
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, allPreloads, nil, 0)
|
||||
|
||||
// Also need to apply the child preload separately (as would happen in normal flow)
|
||||
result = handler.applyPreloadWithRecursion(result, childPreload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Check that the child relation was extended to recursive levels
|
||||
// We should see:
|
||||
// - MAL (with WHERE)
|
||||
// - MAL.DEF
|
||||
// - MAL.MAL_RID_PARENTMASTERTASKITEM (without WHERE)
|
||||
// - MAL.MAL_RID_PARENTMASTERTASKITEM.DEF (extended by recursive logic)
|
||||
foundMALDEF := false
|
||||
foundRecursiveMAL := false
|
||||
foundMALMALDEF := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.DEF" {
|
||||
foundMALDEF = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundRecursiveMAL = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||
foundMALMALDEF = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundMALDEF {
|
||||
t.Errorf("Expected child preload 'MAL.DEF' to be applied. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if !foundRecursiveMAL {
|
||||
t.Errorf("Expected recursive preload 'MAL.MAL_RID_PARENTMASTERTASKITEM' to be created. Operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if !foundMALMALDEF {
|
||||
t.Errorf("Expected child preload to be extended to 'MAL.MAL_RID_PARENTMASTERTASKITEM.DEF' at recursive level. Operations: %v", mock.operations)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRecursivePreloadGeneratesCorrectRelationName tests that the recursive
|
||||
// preload generates the correct FK-based relation name using RelatedKey
|
||||
func TestRecursivePreloadGeneratesCorrectRelationName(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
// Test case 1: With RelatedKey - should generate FK-based name
|
||||
t.Run("WithRelatedKey", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Should generate MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||
foundCorrectRelation := false
|
||||
foundIncorrectRelation := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundCorrectRelation = true
|
||||
}
|
||||
if op == "PreloadRelation:MAL.MAL" {
|
||||
foundIncorrectRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundCorrectRelation {
|
||||
t.Errorf("Expected 'MAL.MAL_RID_PARENTMASTERTASKITEM' relation, operations: %v", mock.operations)
|
||||
}
|
||||
|
||||
if foundIncorrectRelation {
|
||||
t.Error("Should NOT generate 'MAL.MAL' relation when RelatedKey is specified")
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 2: Without RelatedKey - should fallback to old behavior
|
||||
t.Run("WithoutRelatedKey", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
// No RelatedKey
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 0)
|
||||
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Should fallback to MAL.MAL
|
||||
foundFallback := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL" {
|
||||
foundFallback = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundFallback {
|
||||
t.Errorf("Expected fallback 'MAL.MAL' relation when no RelatedKey, operations: %v", mock.operations)
|
||||
}
|
||||
})
|
||||
|
||||
// Test case 3: Depth limit of 8
|
||||
t.Run("DepthLimit", func(t *testing.T) {
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
|
||||
// Start at depth 7 - should create one more level
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth8 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth8 = true
|
||||
}
|
||||
}
|
||||
|
||||
if !foundDepth8 {
|
||||
t.Error("Expected to create recursive level at depth 8")
|
||||
}
|
||||
|
||||
// Start at depth 8 - should NOT create another level
|
||||
mockQuery2 := &mockSelectQuery{operations: []string{}}
|
||||
result2 := handler.applyPreloadWithRecursion(mockQuery2, preload, allPreloads, nil, 8)
|
||||
mock2 := result2.(*mockSelectQuery)
|
||||
|
||||
foundDepth9 := false
|
||||
for _, op := range mock2.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth9 = true
|
||||
}
|
||||
}
|
||||
|
||||
if foundDepth9 {
|
||||
t.Error("Should NOT create recursive level beyond depth 8")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// mockSelectQuery implements common.SelectQuery for testing
|
||||
type mockSelectQuery struct {
|
||||
operations []string
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Model")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Table:"+table)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
for _, col := range columns {
|
||||
m.operations = append(m.operations, "Column:"+col)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Where:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereOr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereIn:"+column)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Order:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Limit")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Offset")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Join:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Group")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Having:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Preload:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||
// Apply the preload modifiers
|
||||
for _, fn := range apply {
|
||||
fn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
m.operations = append(m.operations, "Scan")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||
m.operations = append(m.operations, "ScanModel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
m.operations = append(m.operations, "Count")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
m.operations = append(m.operations, "Exists")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetModel() interface{} {
|
||||
return nil
|
||||
}
|
||||
@@ -32,6 +32,7 @@
|
||||
// - X-Clean-JSON: Boolean to remove null/empty fields
|
||||
// - X-Custom-SQL-Where: Custom SQL WHERE clause (AND)
|
||||
// - X-Custom-SQL-Or: Custom SQL WHERE clause (OR)
|
||||
// - X-Custom-SQL-Join: Custom SQL JOIN clauses (pipe-separated for multiple)
|
||||
//
|
||||
// # Usage Example
|
||||
//
|
||||
@@ -103,8 +104,9 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
openAPIHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
})
|
||||
muxRouter.Handle("/openapi", openAPIHandler).Methods("GET", "OPTIONS")
|
||||
@@ -123,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
||||
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||
|
||||
// Create handler functions for this specific entity
|
||||
entityHandler := createMuxHandler(handler, schema, entity, "")
|
||||
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
||||
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
||||
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||
|
||||
// Apply authentication middleware if provided
|
||||
if authMiddleware != nil {
|
||||
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
||||
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
||||
entityHandler = authMiddleware(entityHandler)
|
||||
entityWithIDHandler = authMiddleware(entityWithIDHandler)
|
||||
metadataHandler = authMiddleware(metadataHandler)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
}
|
||||
|
||||
@@ -161,7 +163,8 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -169,7 +172,7 @@ func createMuxHandler(handler *Handler, schema, entity, idParam string) http.Han
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -180,7 +183,8 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
// Set CORS headers
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
@@ -188,7 +192,7 @@ func createMuxGetHandler(handler *Handler, schema, entity, idParam string) http.
|
||||
if idParam != "" {
|
||||
vars["id"] = mux.Vars(r)[idParam]
|
||||
}
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -200,13 +204,14 @@ func createMuxOptionsHandler(handler *Handler, schema, entity string, allowedMet
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
corsConfig.AllowedMethods = allowedMethods
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
|
||||
// Return metadata in the OPTIONS response body
|
||||
vars := make(map[string]string)
|
||||
vars["schema"] = schema
|
||||
vars["entity"] = entity
|
||||
reqAdapter := router.NewHTTPRequest(r)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, vars)
|
||||
}
|
||||
}
|
||||
@@ -275,9 +280,34 @@ type BunRouterHandler interface {
|
||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||
}
|
||||
|
||||
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||
if authMiddleware == nil {
|
||||
return handler
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
// Create an http.Handler that calls the bunrouter handler
|
||||
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Replace the embedded *http.Request with the middleware-enriched one
|
||||
// so that auth context (user ID, etc.) is visible to the handler.
|
||||
enrichedReq := req
|
||||
enrichedReq.Request = r
|
||||
_ = handler(w, enrichedReq)
|
||||
})
|
||||
|
||||
// Wrap with auth middleware and execute
|
||||
wrappedHandler := authMiddleware(httpHandler)
|
||||
wrappedHandler.ServeHTTP(w, req.Request)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||
// Accepts bunrouter.Router or bunrouter.Group
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||
|
||||
// CORS config
|
||||
corsConfig := common.DefaultCORSConfig()
|
||||
@@ -285,15 +315,16 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
// Add global /openapi route
|
||||
r.Handle("GET", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -315,135 +346,155 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
||||
currentEntity := entity
|
||||
|
||||
// GET and POST for /{schema}/{entity}
|
||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||
|
||||
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||
|
||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||
|
||||
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))
|
||||
|
||||
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))
|
||||
|
||||
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
"id": req.Param("id"),
|
||||
}
|
||||
|
||||
handler.Handle(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
}
|
||||
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))
|
||||
|
||||
// Metadata endpoint
|
||||
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
common.SetCORSHeaders(respAdapter, corsConfig)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))
|
||||
|
||||
// OPTIONS route without ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
|
||||
// OPTIONS route with ID (returns metadata)
|
||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||
respAdapter := router.NewHTTPResponseWriter(w)
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
optionsCorsConfig := corsConfig
|
||||
optionsCorsConfig.AllowedMethods = []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"}
|
||||
common.SetCORSHeaders(respAdapter, optionsCorsConfig)
|
||||
common.SetCORSHeaders(respAdapter, reqAdapter, optionsCorsConfig)
|
||||
params := map[string]string{
|
||||
"schema": currentSchema,
|
||||
"entity": currentEntity,
|
||||
}
|
||||
reqAdapter := router.NewBunRouterRequest(req)
|
||||
|
||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||
return nil
|
||||
})
|
||||
@@ -458,8 +509,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
||||
// Create bunrouter
|
||||
bunRouter := bunrouter.New()
|
||||
|
||||
// Setup routes
|
||||
SetupBunRouterRoutes(bunRouter, handler)
|
||||
// Setup routes without authentication
|
||||
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
@@ -479,7 +530,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
||||
apiGroup := bunRouter.NewGroup("/api")
|
||||
|
||||
// Setup RestHeadSpec routes on the group - routes will be under /api
|
||||
SetupBunRouterRoutes(apiGroup, handler)
|
||||
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||
|
||||
// Start server
|
||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
@@ -9,6 +10,17 @@ import (
|
||||
|
||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||
hookCtx.Abort = true
|
||||
hookCtx.AbortMessage = err.Error()
|
||||
hookCtx.AbortCode = http.StatusUnauthorized
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Hook 1: BeforeRead - Load security rules
|
||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
||||
return security.LogDataAccess(secCtx)
|
||||
})
|
||||
|
||||
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelUpdateAllowed(secCtx)
|
||||
})
|
||||
|
||||
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||
secCtx := newSecurityContext(hookCtx)
|
||||
return security.CheckModelDeleteAllowed(secCtx)
|
||||
})
|
||||
|
||||
logger.Info("Security hooks registered for restheadspec handler")
|
||||
}
|
||||
|
||||
|
||||
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
527
pkg/restheadspec/xfiles_integration_test.go
Normal file
@@ -0,0 +1,527 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package restheadspec
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockSelectQuery implements common.SelectQuery for testing (integration version)
|
||||
type mockSelectQuery struct {
|
||||
operations []string
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Model")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Table(table string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Table:"+table)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Column(columns ...string) common.SelectQuery {
|
||||
for _, col := range columns {
|
||||
m.operations = append(m.operations, "Column:"+col)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ColumnExpr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "ColumnExpr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Where:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereOr:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) WhereIn(column string, values interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "WhereIn:"+column)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Order(order string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Order:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) OrderExpr(order string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "OrderExpr:"+order)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Limit(limit int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Limit")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Offset(offset int) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Offset")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Join(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Join:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) LeftJoin(join string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "LeftJoin:"+join)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Group(columns string) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Group")
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Having(query string, args ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Having:"+query)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Preload(relation string, conditions ...interface{}) common.SelectQuery {
|
||||
m.operations = append(m.operations, "Preload:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) PreloadRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "PreloadRelation:"+relation)
|
||||
// Apply the preload modifiers
|
||||
for _, fn := range apply {
|
||||
fn(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) JoinRelation(relation string, apply ...func(common.SelectQuery) common.SelectQuery) common.SelectQuery {
|
||||
m.operations = append(m.operations, "JoinRelation:"+relation)
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||
m.operations = append(m.operations, "Scan")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) ScanModel(ctx context.Context) error {
|
||||
m.operations = append(m.operations, "ScanModel")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Count(ctx context.Context) (int, error) {
|
||||
m.operations = append(m.operations, "Count")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) Exists(ctx context.Context) (bool, error) {
|
||||
m.operations = append(m.operations, "Exists")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetUnderlyingQuery() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSelectQuery) GetModel() interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestXFilesRecursivePreload is an integration test that validates the XFiles
|
||||
// recursive preload functionality using real test data files.
|
||||
//
|
||||
// This test ensures:
|
||||
// 1. XFiles request JSON is correctly parsed into PreloadOptions
|
||||
// 2. Recursive preload generates correct FK-based relation names (MAL_RID_PARENTMASTERTASKITEM)
|
||||
// 3. Parent WHERE clauses don't leak to child levels
|
||||
// 4. Child relations (like DEF) are extended to all recursive levels
|
||||
// 5. Hierarchical data structure matches expected output
|
||||
func TestXFilesRecursivePreload(t *testing.T) {
|
||||
// Load the XFiles request configuration
|
||||
requestPath := filepath.Join("..", "..", "tests", "data", "xfiles.request.json")
|
||||
requestData, err := os.ReadFile(requestPath)
|
||||
require.NoError(t, err, "Failed to read xfiles.request.json")
|
||||
|
||||
var xfileConfig XFiles
|
||||
err = json.Unmarshal(requestData, &xfileConfig)
|
||||
require.NoError(t, err, "Failed to parse xfiles.request.json")
|
||||
|
||||
// Create handler and parse XFiles into PreloadOptions
|
||||
handler := &Handler{}
|
||||
options := &ExtendedRequestOptions{
|
||||
RequestOptions: common.RequestOptions{
|
||||
Preload: []common.PreloadOption{},
|
||||
},
|
||||
}
|
||||
|
||||
// Process the XFiles configuration - start with the root table
|
||||
handler.processXFilesRelations(&xfileConfig, options, "")
|
||||
|
||||
// Verify that preload options were created
|
||||
require.NotEmpty(t, options.Preload, "Expected preload options to be created")
|
||||
|
||||
// Test 1: Verify mastertaskitem preload is marked as recursive with correct RelatedKey
|
||||
t.Run("RecursivePreloadHasRelatedKey", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, recursivePreload, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// RelatedKey should be the parent relationship key (MTL -> MAL)
|
||||
assert.Equal(t, "rid_mastertask", recursivePreload.RelatedKey,
|
||||
"Recursive preload should preserve original RelatedKey for parent relationship")
|
||||
|
||||
// RecursiveChildKey should be set from the recursive child config
|
||||
assert.Equal(t, "rid_parentmastertaskitem", recursivePreload.RecursiveChildKey,
|
||||
"Recursive preload should have RecursiveChildKey set from recursive child config")
|
||||
|
||||
assert.True(t, recursivePreload.Recursive, "mastertaskitem preload should be marked as recursive")
|
||||
})
|
||||
|
||||
// Test 2: Verify mastertaskitem has WHERE clause for filtering root items
|
||||
t.Run("RootLevelHasWhereClause", func(t *testing.T) {
|
||||
var rootPreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL" {
|
||||
rootPreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, rootPreload, "Expected to find mastertaskitem preload")
|
||||
assert.NotEmpty(t, rootPreload.Where, "Mastertaskitem should have WHERE clause")
|
||||
// The WHERE clause should filter for root items (rid_parentmastertaskitem is null)
|
||||
assert.True(t, rootPreload.Recursive, "Mastertaskitem preload should be marked as recursive")
|
||||
})
|
||||
|
||||
// Test 3: Verify actiondefinition relation exists for mastertaskitem
|
||||
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||
var defPreload *common.PreloadOption
|
||||
for i := range options.Preload {
|
||||
preload := &options.Preload[i]
|
||||
if preload.Relation == "MTL.MAL.DEF" {
|
||||
defPreload = preload
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, defPreload, "Expected to find actiondefinition preload for mastertaskitem")
|
||||
assert.Equal(t, "rid_actiondefinition", defPreload.ForeignKey,
|
||||
"actiondefinition preload should have ForeignKey set")
|
||||
})
|
||||
|
||||
// Test 4: Verify relation name generation with mock query
|
||||
t.Run("RelationNameGeneration", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
found := false
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// Create mock query to track operations
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
|
||||
// Apply the recursive preload
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// Verify the correct FK-based relation name was generated
|
||||
foundCorrectRelation := false
|
||||
|
||||
for _, op := range mock.operations {
|
||||
// Should generate: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundCorrectRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundCorrectRelation,
|
||||
"Expected FK-based relation name 'MTL.MAL.MAL_RID_PARENTMASTERTASKITEM' to be generated. Operations: %v",
|
||||
mock.operations)
|
||||
})
|
||||
|
||||
// Test 5: Verify WHERE clause is cleared for recursive levels
|
||||
t.Run("WhereClauseClearedForChildren", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
found := false
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, found, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
// The root level has a WHERE clause (rid_parentmastertaskitem is null)
|
||||
// But when we apply recursion, it should be cleared
|
||||
assert.NotEmpty(t, recursivePreload.Where, "Root preload should have WHERE clause")
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// After the first level, WHERE clauses should not be reapplied
|
||||
// We check that the recursive relation was created (which means WHERE was cleared internally)
|
||||
foundRecursiveRelation := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundRecursiveRelation = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundRecursiveRelation,
|
||||
"Recursive relation should be created (WHERE clause should be cleared internally)")
|
||||
})
|
||||
|
||||
// Test 6: Verify child relations are extended to recursive levels
|
||||
t.Run("ChildRelationsExtended", func(t *testing.T) {
|
||||
// Find the mastertaskitem preload - it should be marked as recursive
|
||||
var recursivePreload common.PreloadOption
|
||||
foundRecursive := false
|
||||
|
||||
for _, preload := range options.Preload {
|
||||
if preload.Relation == "MTL.MAL" && preload.Recursive {
|
||||
recursivePreload = preload
|
||||
foundRecursive = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
require.True(t, foundRecursive, "Expected to find recursive mastertaskitem preload MTL.MAL")
|
||||
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, recursivePreload, options.Preload, nil, 0)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
// actiondefinition should be extended to the recursive level
|
||||
// Expected: MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF
|
||||
foundExtendedDEF := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MTL.MAL.MAL_RID_PARENTMASTERTASKITEM.DEF" {
|
||||
foundExtendedDEF = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundExtendedDEF,
|
||||
"Expected actiondefinition relation to be extended to recursive level. Operations: %v",
|
||||
mock.operations)
|
||||
})
|
||||
}
|
||||
|
||||
// TestXFilesRecursivePreloadDepth tests that recursive preloads respect the depth limit of 8
|
||||
func TestXFilesRecursivePreloadDepth(t *testing.T) {
|
||||
handler := &Handler{}
|
||||
|
||||
preload := common.PreloadOption{
|
||||
Relation: "MAL",
|
||||
Recursive: true,
|
||||
RelatedKey: "rid_parentmastertaskitem",
|
||||
}
|
||||
|
||||
allPreloads := []common.PreloadOption{preload}
|
||||
|
||||
t.Run("Depth7CreatesLevel8", func(t *testing.T) {
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 7)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth8 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth8 = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(t, foundDepth8, "Should create level 8 when starting at depth 7")
|
||||
})
|
||||
|
||||
t.Run("Depth8DoesNotCreateLevel9", func(t *testing.T) {
|
||||
mockQuery := &mockSelectQuery{operations: []string{}}
|
||||
result := handler.applyPreloadWithRecursion(mockQuery, preload, allPreloads, nil, 8)
|
||||
mock := result.(*mockSelectQuery)
|
||||
|
||||
foundDepth9 := false
|
||||
for _, op := range mock.operations {
|
||||
if op == "PreloadRelation:MAL.MAL_RID_PARENTMASTERTASKITEM" {
|
||||
foundDepth9 = true
|
||||
}
|
||||
}
|
||||
|
||||
assert.False(t, foundDepth9, "Should NOT create level 9 (depth limit is 8)")
|
||||
})
|
||||
}
|
||||
|
||||
// TestXFilesResponseStructure validates the actual structure of the response
|
||||
// This test can be expanded when we have a full database integration test environment
|
||||
func TestXFilesResponseStructure(t *testing.T) {
|
||||
// Load the expected correct response
|
||||
correctResponsePath := filepath.Join("..", "..", "tests", "data", "xfiles.response.correct.json")
|
||||
correctData, err := os.ReadFile(correctResponsePath)
|
||||
require.NoError(t, err, "Failed to read xfiles.response.correct.json")
|
||||
|
||||
var correctResponse []map[string]interface{}
|
||||
err = json.Unmarshal(correctData, &correctResponse)
|
||||
require.NoError(t, err, "Failed to parse xfiles.response.correct.json")
|
||||
|
||||
// Test 1: Verify root level has exactly 1 masterprocess
|
||||
t.Run("RootLevelHasOneItem", func(t *testing.T) {
|
||||
assert.Len(t, correctResponse, 1, "Root level should have exactly 1 masterprocess record")
|
||||
})
|
||||
|
||||
// Test 2: Verify the root item has MTL relation
|
||||
t.Run("RootHasMTLRelation", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, exists := rootItem["MTL"]
|
||||
assert.True(t, exists, "Root item should have MTL relation")
|
||||
assert.NotNil(t, mtl, "MTL relation should not be null")
|
||||
})
|
||||
|
||||
// Test 3: Verify MTL has MAL items
|
||||
t.Run("MTLHasMALItems", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, exists := firstMTL["MAL"]
|
||||
assert.True(t, exists, "MTL item should have MAL relation")
|
||||
assert.NotNil(t, mal, "MAL relation should not be null")
|
||||
})
|
||||
|
||||
// Test 4: Verify MAL items have MAL_RID_PARENTMASTERTASKITEM relation (recursive)
|
||||
t.Run("MALHasRecursiveRelation", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
require.NotEmpty(t, mal, "MAL should have items")
|
||||
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
// The key assertion: check for FK-based relation name
|
||||
recursiveRelation, exists := firstMAL["MAL_RID_PARENTMASTERTASKITEM"]
|
||||
assert.True(t, exists,
|
||||
"MAL item should have MAL_RID_PARENTMASTERTASKITEM relation (FK-based name)")
|
||||
|
||||
// It can be null or an array, depending on whether this item has children
|
||||
if recursiveRelation != nil {
|
||||
_, isArray := recursiveRelation.([]interface{})
|
||||
assert.True(t, isArray,
|
||||
"MAL_RID_PARENTMASTERTASKITEM should be an array when not null")
|
||||
}
|
||||
})
|
||||
|
||||
// Test 5: Verify "Receive COB Document for" appears as a child, not at root
|
||||
t.Run("ChildItemsAreNested", func(t *testing.T) {
|
||||
// This test verifies that "Receive COB Document for" doesn't appear
|
||||
// multiple times at the wrong level, but is properly nested
|
||||
|
||||
// Count how many times we find this description at the MAL level (should be 0 or 1)
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
|
||||
// Count root-level MAL items (before the fix, there were 12; should be 1)
|
||||
assert.Len(t, mal, 1,
|
||||
"MAL should have exactly 1 root-level item (before fix: 12 duplicates)")
|
||||
|
||||
// Verify the root item has a description
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
description, exists := firstMAL["description"]
|
||||
assert.True(t, exists, "MAL item should have a description")
|
||||
assert.Equal(t, "Capture COB Information", description,
|
||||
"Root MAL item should be 'Capture COB Information'")
|
||||
})
|
||||
|
||||
// Test 6: Verify DEF relation exists at MAL level
|
||||
t.Run("DEFRelationExists", func(t *testing.T) {
|
||||
require.NotEmpty(t, correctResponse, "Response should not be empty")
|
||||
rootItem := correctResponse[0]
|
||||
|
||||
mtl, ok := rootItem["MTL"].([]interface{})
|
||||
require.True(t, ok, "MTL should be an array")
|
||||
require.NotEmpty(t, mtl, "MTL should have items")
|
||||
|
||||
firstMTL, ok := mtl[0].(map[string]interface{})
|
||||
require.True(t, ok, "MTL item should be a map")
|
||||
|
||||
mal, ok := firstMTL["MAL"].([]interface{})
|
||||
require.True(t, ok, "MAL should be an array")
|
||||
require.NotEmpty(t, mal, "MAL should have items")
|
||||
|
||||
firstMAL, ok := mal[0].(map[string]interface{})
|
||||
require.True(t, ok, "MAL item should be a map")
|
||||
|
||||
// Verify DEF relation exists (child relation extension)
|
||||
def, exists := firstMAL["DEF"]
|
||||
assert.True(t, exists, "MAL item should have DEF relation")
|
||||
|
||||
// DEF can be null or an object
|
||||
if def != nil {
|
||||
_, isMap := def.(map[string]interface{})
|
||||
assert.True(t, isMap, "DEF should be an object when not null")
|
||||
}
|
||||
})
|
||||
}
|
||||
527
pkg/security/OAUTH2.md
Normal file
527
pkg/security/OAUTH2.md
Normal file
@@ -0,0 +1,527 @@
|
||||
# OAuth2 Authentication Guide
|
||||
|
||||
## Overview
|
||||
|
||||
The security package provides OAuth2 authentication support for any OAuth2-compliant provider including Google, GitHub, Microsoft, Facebook, and custom providers.
|
||||
|
||||
## Features
|
||||
|
||||
- **Universal OAuth2 Support**: Works with any OAuth2 provider
|
||||
- **Pre-configured Providers**: Google, GitHub, Microsoft, Facebook
|
||||
- **Multi-Provider Support**: Use all OAuth2 providers simultaneously
|
||||
- **Custom Providers**: Easy configuration for any OAuth2 service
|
||||
- **Session Management**: Database-backed session storage
|
||||
- **Token Refresh**: Automatic token refresh support
|
||||
- **State Validation**: Built-in CSRF protection
|
||||
- **User Auto-Creation**: Automatically creates users on first login
|
||||
- **Unified Authentication**: OAuth2 and traditional auth share same session storage
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Database Setup
|
||||
|
||||
```sql
|
||||
-- Run the schema from database_schema.sql
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
);
|
||||
|
||||
-- OAuth2 stored procedures (7 functions)
|
||||
-- See database_schema.sql for full implementation
|
||||
```
|
||||
|
||||
### 2. Google OAuth2
|
||||
|
||||
```go
|
||||
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
|
||||
// Create authenticator
|
||||
oauth2Auth := security.NewGoogleAuthenticator(
|
||||
"your-google-client-id",
|
||||
"your-google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Login route - redirects to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL(state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback route - handles Google response
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
```
|
||||
|
||||
### 3. GitHub OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewGitHubAuthenticator(
|
||||
"your-github-client-id",
|
||||
"your-github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Same routes pattern as Google
|
||||
router.HandleFunc("/auth/github/login", ...)
|
||||
router.HandleFunc("/auth/github/callback", ...)
|
||||
```
|
||||
|
||||
### 4. Microsoft OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewMicrosoftAuthenticator(
|
||||
"your-microsoft-client-id",
|
||||
"your-microsoft-client-secret",
|
||||
"http://localhost:8080/auth/microsoft/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Facebook OAuth2
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewFacebookAuthenticator(
|
||||
"your-facebook-client-id",
|
||||
"your-facebook-client-secret",
|
||||
"http://localhost:8080/auth/facebook/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
## Custom OAuth2 Provider
|
||||
|
||||
```go
|
||||
oauth2Auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "your-client-id",
|
||||
ClientSecret: "your-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||
TokenURL: "https://your-provider.com/oauth/token",
|
||||
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||
DB: db,
|
||||
ProviderName: "custom",
|
||||
|
||||
// Optional: Custom user info parser
|
||||
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||
return &security.UserContext{
|
||||
UserName: userInfo["username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["id"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
```
|
||||
|
||||
## Protected Routes
|
||||
|
||||
```go
|
||||
// Create security provider
|
||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := security.NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
|
||||
// Apply middleware to protected routes
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(security.NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
```
|
||||
|
||||
## Token Refresh
|
||||
|
||||
OAuth2 access tokens expire after a period of time. Use the refresh token to obtain a new access token without requiring the user to log in again.
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google", "github", etc.
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- The refresh token is returned in the `LoginResponse.RefreshToken` field after successful OAuth2 callback
|
||||
- Store the refresh token securely on the client side
|
||||
- Each provider must be configured with the appropriate scopes to receive a refresh token (e.g., `access_type=offline` for Google)
|
||||
- The `OAuth2RefreshToken` method requires the provider name to identify which OAuth2 provider to use for refreshing
|
||||
|
||||
## Logout
|
||||
|
||||
```go
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
|
||||
oauth2Auth.Logout(r.Context(), security.LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
```
|
||||
|
||||
## Multi-Provider Setup
|
||||
|
||||
```go
|
||||
// Single DatabaseAuthenticator with ALL OAuth2 providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
|
||||
// Get list of configured providers
|
||||
providers := auth.OAuth2GetProviders() // ["google", "github"]
|
||||
|
||||
// Google routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google",
|
||||
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
// ... handle response
|
||||
})
|
||||
|
||||
// GitHub routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github",
|
||||
r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
// ... handle response
|
||||
})
|
||||
|
||||
// Use same authenticator for protected routes - works for ALL providers
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := security.NewSecurityList(provider)
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### OAuth2Config Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| ClientID | string | OAuth2 client ID from provider |
|
||||
| ClientSecret | string | OAuth2 client secret |
|
||||
| RedirectURL | string | Callback URL registered with provider |
|
||||
| Scopes | []string | OAuth2 scopes to request |
|
||||
| AuthURL | string | Provider's authorization endpoint |
|
||||
| TokenURL | string | Provider's token endpoint |
|
||||
| UserInfoURL | string | Provider's user info endpoint |
|
||||
| DB | *sql.DB | Database connection for sessions |
|
||||
| UserInfoParser | func | Custom parser for user info (optional) |
|
||||
| StateValidator | func | Custom state validator (optional) |
|
||||
| ProviderName | string | Provider name for logging (optional) |
|
||||
|
||||
## User Info Parsing
|
||||
|
||||
The default parser extracts these standard fields:
|
||||
- `sub` → RemoteID
|
||||
- `email` → Email, UserName
|
||||
- `name` → UserName
|
||||
- `login` → UserName (GitHub)
|
||||
|
||||
Custom parser example:
|
||||
|
||||
```go
|
||||
UserInfoParser: func(userInfo map[string]any) (*security.UserContext, error) {
|
||||
// Extract custom fields
|
||||
ctx := &security.UserContext{
|
||||
UserName: userInfo["preferred_username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["sub"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo, // Store all claims
|
||||
}
|
||||
|
||||
// Add custom roles based on provider data
|
||||
if groups, ok := userInfo["groups"].([]interface{}); ok {
|
||||
for _, g := range groups {
|
||||
ctx.Roles = append(ctx.Roles, g.(string))
|
||||
}
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
```
|
||||
|
||||
## Security Best Practices
|
||||
|
||||
1. **Always use HTTPS in production**
|
||||
```go
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Secure: true, // Only send over HTTPS
|
||||
HttpOnly: true, // Prevent XSS access
|
||||
SameSite: http.SameSiteLaxMode, // CSRF protection
|
||||
})
|
||||
```
|
||||
|
||||
2. **Store secrets securely**
|
||||
```go
|
||||
clientID := os.Getenv("GOOGLE_CLIENT_ID")
|
||||
clientSecret := os.Getenv("GOOGLE_CLIENT_SECRET")
|
||||
```
|
||||
|
||||
3. **Validate redirect URLs**
|
||||
- Only register trusted redirect URLs with OAuth2 providers
|
||||
- Never accept redirect URL from request parameters
|
||||
|
||||
5. **Session expiration**
|
||||
- OAuth2 sessions automatically expire based on token expiry
|
||||
- Clean up expired sessions periodically:
|
||||
```sql
|
||||
DELETE FROM user_sessions WHERE expires_at < NOW();
|
||||
```
|
||||
|
||||
4. **State parameter**
|
||||
- Automatically generated with cryptographic randomness
|
||||
- One-time use and expires after 10 minutes
|
||||
- Prevents CSRF attacks
|
||||
|
||||
## Implementation Details
|
||||
|
||||
All database operations use stored procedures for consistency and security:
|
||||
- `resolvespec_oauth_getorcreateuser` - Find or create OAuth2 user
|
||||
- `resolvespec_oauth_createsession` - Create OAuth2 session
|
||||
- `resolvespec_oauth_getsession` - Validate and retrieve session
|
||||
- `resolvespec_oauth_deletesession` - Logout/delete session
|
||||
- `resolvespec_oauth_getrefreshtoken` - Get session by refresh token
|
||||
- `resolvespec_oauth_updaterefreshtoken` - Update tokens after refresh
|
||||
- `resolvespec_oauth_getuser` - Get user data by ID
|
||||
|
||||
## Provider Setup Guides
|
||||
|
||||
### Google
|
||||
|
||||
1. Go to [Google Cloud Console](https://console.cloud.google.com/)
|
||||
2. Create a new project or select existing
|
||||
3. Enable Google+ API
|
||||
4. Create OAuth 2.0 credentials
|
||||
5. Add authorized redirect URI: `http://localhost:8080/auth/google/callback`
|
||||
6. Copy Client ID and Client Secret
|
||||
|
||||
### GitHub
|
||||
|
||||
1. Go to [GitHub Developer Settings](https://github.com/settings/developers)
|
||||
2. Click "New OAuth App"
|
||||
3. Set Homepage URL: `http://localhost:8080`
|
||||
4. Set Authorization callback URL: `http://localhost:8080/auth/github/callback`
|
||||
5. Copy Client ID and Client Secret
|
||||
|
||||
### Microsoft
|
||||
|
||||
1. Go to [Azure Portal](https://portal.azure.com/)
|
||||
2. Register new application in Azure AD
|
||||
3. Add redirect URI: `http://localhost:8080/auth/microsoft/callback`
|
||||
4. Create client secret
|
||||
5. Copy Application (client) ID and secret value
|
||||
|
||||
### Facebook
|
||||
|
||||
1. Go to [Facebook Developers](https://developers.facebook.com/)
|
||||
2. Create new app
|
||||
3. Add Facebook Login product
|
||||
4. Set Valid OAuth Redirect URIs: `http://localhost:8080/auth/facebook/callback`
|
||||
5. Copy App ID and App Secret
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "redirect_uri_mismatch" error
|
||||
- Ensure the redirect URL in code matches exactly with provider configuration
|
||||
- Include protocol (http/https), domain, port, and path
|
||||
|
||||
### "invalid_client" error
|
||||
- Verify Client ID and Client Secret are correct
|
||||
- Check if credentials are for the correct environment (dev/prod)
|
||||
|
||||
### "invalid_grant" error during token exchange
|
||||
- State parameter validation failed
|
||||
- Token might have expired
|
||||
- Check server time synchronization
|
||||
|
||||
### User not created after successful OAuth2 login
|
||||
- Check database constraints (username/email unique)
|
||||
- Verify UserInfoParser is extracting required fields
|
||||
- Check database logs for constraint violations
|
||||
|
||||
## Testing
|
||||
|
||||
```go
|
||||
func TestOAuth2Flow(t *testing.T) {
|
||||
// Mock database
|
||||
db, mock, _ := sqlmock.New()
|
||||
|
||||
oauth2Auth := security.NewGoogleAuthenticator(
|
||||
"test-client-id",
|
||||
"test-client-secret",
|
||||
"http://localhost/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Test state generation
|
||||
state, err := oauth2Auth.GenerateState()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, state)
|
||||
|
||||
// Test auth URL generation
|
||||
authURL := oauth2Auth.GetAuthURL(state)
|
||||
assert.Contains(t, authURL, "accounts.google.com")
|
||||
assert.Contains(t, authURL, state)
|
||||
}
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### DatabaseAuthenticator with OAuth2
|
||||
|
||||
| Method | Description |
|
||||
|--------|-------------|
|
||||
| WithOAuth2(cfg) | Adds OAuth2 provider (can be called multiple times, returns *DatabaseAuthenticator) |
|
||||
| OAuth2GetAuthURL(provider, state) | Returns OAuth2 authorization URL for specified provider |
|
||||
| OAuth2GenerateState() | Generates random state for CSRF protection |
|
||||
| OAuth2HandleCallback(ctx, provider, code, state) | Exchanges code for token and creates session |
|
||||
| OAuth2RefreshToken(ctx, refreshToken, provider) | Refreshes expired access token using refresh token |
|
||||
| OAuth2GetProviders() | Returns list of configured OAuth2 provider names |
|
||||
| Login(ctx, req) | Standard username/password login |
|
||||
| Logout(ctx, req) | Invalidates session (works for both OAuth2 and regular sessions) |
|
||||
| Authenticate(r) | Validates session token from request (works for both OAuth2 and regular sessions) |
|
||||
|
||||
### Pre-configured Constructors
|
||||
|
||||
- `NewGoogleAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewGitHubAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewFacebookAuthenticator(clientID, secret, redirectURL, db)` - Single provider
|
||||
- `NewMultiProviderAuthenticator(db, configs)` - Multiple providers at once
|
||||
|
||||
All return `*DatabaseAuthenticator` with OAuth2 pre-configured.
|
||||
|
||||
For multiple providers, use `WithOAuth2()` multiple times or `NewMultiProviderAuthenticator()`.
|
||||
|
||||
## Examples
|
||||
|
||||
Complete working examples available in `oauth2_examples.go`:
|
||||
- Basic Google OAuth2
|
||||
- GitHub OAuth2
|
||||
- Custom provider
|
||||
- Multi-provider setup
|
||||
- Token refresh
|
||||
- Logout flow
|
||||
- Complete integration with security middleware
|
||||
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
281
pkg/security/OAUTH2_REFRESH_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,281 @@
|
||||
# OAuth2 Refresh Token - Quick Reference
|
||||
|
||||
## Quick Setup (3 Steps)
|
||||
|
||||
### 1. Initialize Authenticator
|
||||
```go
|
||||
auth := security.NewGoogleAuthenticator(
|
||||
"client-id",
|
||||
"client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
```
|
||||
|
||||
### 2. OAuth2 Login Flow
|
||||
```go
|
||||
// Login - Redirect to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback - Store tokens
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, _ := auth.OAuth2HandleCallback(
|
||||
r.Context(),
|
||||
"google",
|
||||
r.URL.Query().Get("code"),
|
||||
r.URL.Query().Get("state"),
|
||||
)
|
||||
|
||||
// Save refresh_token on client
|
||||
// loginResp.RefreshToken - Store this securely!
|
||||
// loginResp.Token - Session token for API calls
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Refresh Endpoint
|
||||
```go
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh token
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Multi-Provider Example
|
||||
|
||||
```go
|
||||
// Configure multiple providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "google",
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "github",
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
})
|
||||
|
||||
// Refresh with provider selection
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google" or "github"
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 401)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Client-Side JavaScript
|
||||
|
||||
```javascript
|
||||
// Automatic token refresh on 401
|
||||
async function apiCall(url) {
|
||||
let response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
|
||||
// Token expired - refresh it
|
||||
if (response.status === 401) {
|
||||
await refreshToken();
|
||||
|
||||
// Retry request with new token
|
||||
response = await fetch(url, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function refreshToken() {
|
||||
const response = await fetch('/auth/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
refresh_token: localStorage.getItem('refresh_token'),
|
||||
provider: localStorage.getItem('provider')
|
||||
})
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
localStorage.setItem('access_token', data.token);
|
||||
localStorage.setItem('refresh_token', data.refresh_token);
|
||||
} else {
|
||||
// Refresh failed - redirect to login
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API Methods
|
||||
|
||||
| Method | Parameters | Returns |
|
||||
|--------|-----------|---------|
|
||||
| `OAuth2RefreshToken` | `ctx, refreshToken, provider` | `*LoginResponse, error` |
|
||||
| `OAuth2HandleCallback` | `ctx, provider, code, state` | `*LoginResponse, error` |
|
||||
| `OAuth2GetAuthURL` | `provider, state` | `string, error` |
|
||||
| `OAuth2GenerateState` | none | `string, error` |
|
||||
| `OAuth2GetProviders` | none | `[]string` |
|
||||
|
||||
---
|
||||
|
||||
## LoginResponse Structure
|
||||
|
||||
```go
|
||||
type LoginResponse struct {
|
||||
Token string // New session token for API calls
|
||||
RefreshToken string // Refresh token (store securely)
|
||||
User *UserContext // User information
|
||||
ExpiresIn int64 // Seconds until token expires
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Database Stored Procedures
|
||||
|
||||
- `resolvespec_oauth_getrefreshtoken(refresh_token)` - Get session by refresh token
|
||||
- `resolvespec_oauth_updaterefreshtoken(update_data)` - Update tokens after refresh
|
||||
- `resolvespec_oauth_getuser(user_id)` - Get user data
|
||||
|
||||
All procedures return: `{p_success bool, p_error text, p_data jsonb}`
|
||||
|
||||
---
|
||||
|
||||
## Common Errors
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `invalid or expired refresh token` | Token revoked/expired | Re-authenticate user |
|
||||
| `OAuth2 provider 'xxx' not found` | Provider not configured | Add with `WithOAuth2()` |
|
||||
| `failed to refresh token with provider` | Provider rejected request | Check credentials, re-auth user |
|
||||
|
||||
---
|
||||
|
||||
## Security Checklist
|
||||
|
||||
- [ ] Use HTTPS for all OAuth2 endpoints
|
||||
- [ ] Store refresh tokens securely (HttpOnly cookies or encrypted storage)
|
||||
- [ ] Set cookie flags: `HttpOnly`, `Secure`, `SameSite=Strict`
|
||||
- [ ] Implement rate limiting on refresh endpoint
|
||||
- [ ] Log refresh attempts for audit
|
||||
- [ ] Rotate tokens on refresh
|
||||
- [ ] Revoke old sessions after successful refresh
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
# 1. Login and get refresh token
|
||||
curl http://localhost:8080/auth/google/login
|
||||
# Follow OAuth2 flow, get refresh_token from callback response
|
||||
|
||||
# 2. Refresh token
|
||||
curl -X POST http://localhost:8080/auth/refresh \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"refresh_token":"ya29.xxx","provider":"google"}'
|
||||
|
||||
# 3. Use new token
|
||||
curl http://localhost:8080/api/protected \
|
||||
-H "Authorization: Bearer sess_abc123..."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Pre-configured Providers
|
||||
|
||||
```go
|
||||
// Google
|
||||
auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// GitHub
|
||||
auth := security.NewGitHubAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// Microsoft
|
||||
auth := security.NewMicrosoftAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// Facebook
|
||||
auth := security.NewFacebookAuthenticator(clientID, secret, redirectURL, db)
|
||||
|
||||
// All providers at once
|
||||
auth := security.NewMultiProviderAuthenticator(db, map[string]security.OAuth2Config{
|
||||
"google": {...},
|
||||
"github": {...},
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Provider-Specific Notes
|
||||
|
||||
### Google
|
||||
- Add `access_type=offline` to get refresh token
|
||||
- Add `prompt=consent` to force consent screen
|
||||
```go
|
||||
authURL += "&access_type=offline&prompt=consent"
|
||||
```
|
||||
|
||||
### GitHub
|
||||
- Refresh tokens not always provided
|
||||
- May need to request `offline_access` scope
|
||||
|
||||
### Microsoft
|
||||
- Use `offline_access` scope for refresh token
|
||||
|
||||
### Facebook
|
||||
- Tokens expire after 60 days by default
|
||||
- Check app settings for token expiration policy
|
||||
|
||||
---
|
||||
|
||||
## Complete Example
|
||||
|
||||
See `/pkg/security/oauth2_examples.go` line 250 for full working example.
|
||||
|
||||
For detailed documentation see `/pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md`.
|
||||
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
495
pkg/security/OAUTH2_REFRESH_TOKEN_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,495 @@
|
||||
# OAuth2 Refresh Token Implementation
|
||||
|
||||
## Overview
|
||||
|
||||
OAuth2 refresh token functionality is **fully implemented** in the ResolveSpec security package. This allows refreshing expired access tokens without requiring users to re-authenticate.
|
||||
|
||||
## Implementation Status: ✅ COMPLETE
|
||||
|
||||
### Components Implemented
|
||||
|
||||
1. **✅ Database Schema** - Tables and stored procedures
|
||||
2. **✅ Go Methods** - OAuth2RefreshToken implementation
|
||||
3. **✅ Thread Safety** - Mutex protection for provider map
|
||||
4. **✅ Examples** - Working code examples
|
||||
5. **✅ Documentation** - Complete API reference
|
||||
|
||||
---
|
||||
|
||||
## 1. Database Schema
|
||||
|
||||
### Tables Modified
|
||||
|
||||
```sql
|
||||
-- user_sessions table with OAuth2 token fields
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT, -- OAuth2 access token
|
||||
refresh_token TEXT, -- OAuth2 refresh token
|
||||
token_type VARCHAR(50), -- "Bearer", etc.
|
||||
auth_provider VARCHAR(50) -- "google", "github", etc.
|
||||
);
|
||||
```
|
||||
|
||||
### Stored Procedures
|
||||
|
||||
**`resolvespec_oauth_getrefreshtoken(p_refresh_token)`**
|
||||
- Gets OAuth2 session data by refresh token
|
||||
- Returns: `{user_id, access_token, token_type, expiry}`
|
||||
- Location: `database_schema.sql:714`
|
||||
|
||||
**`resolvespec_oauth_updaterefreshtoken(p_update_data)`**
|
||||
- Updates session with new tokens after refresh
|
||||
- Input: `{user_id, old_refresh_token, new_session_token, new_access_token, new_refresh_token, expires_at}`
|
||||
- Location: `database_schema.sql:752`
|
||||
|
||||
**`resolvespec_oauth_getuser(p_user_id)`**
|
||||
- Gets user data by ID for building UserContext
|
||||
- Location: `database_schema.sql:791`
|
||||
|
||||
---
|
||||
|
||||
## 2. Go Implementation
|
||||
|
||||
### Method Signature
|
||||
|
||||
```go
|
||||
func (a *DatabaseAuthenticator) OAuth2RefreshToken(
|
||||
ctx context.Context,
|
||||
refreshToken string,
|
||||
providerName string,
|
||||
) (*LoginResponse, error)
|
||||
```
|
||||
|
||||
**Location:** `pkg/security/oauth2_methods.go:375`
|
||||
|
||||
### Implementation Flow
|
||||
|
||||
```
|
||||
1. Validate provider exists
|
||||
├─ getOAuth2Provider(providerName) with RLock
|
||||
└─ Return error if provider not configured
|
||||
|
||||
2. Get session from database
|
||||
├─ Call resolvespec_oauth_getrefreshtoken(refreshToken)
|
||||
└─ Parse session data {user_id, access_token, token_type, expiry}
|
||||
|
||||
3. Refresh token with OAuth2 provider
|
||||
├─ Create oauth2.Token from stored data
|
||||
├─ Use provider.config.TokenSource(ctx, oldToken)
|
||||
└─ Call tokenSource.Token() to get new token
|
||||
|
||||
4. Generate new session token
|
||||
└─ Use OAuth2GenerateState() for secure random token
|
||||
|
||||
5. Update database
|
||||
├─ Call resolvespec_oauth_updaterefreshtoken()
|
||||
└─ Store new session_token, access_token, refresh_token
|
||||
|
||||
6. Get user data
|
||||
├─ Call resolvespec_oauth_getuser(user_id)
|
||||
└─ Build UserContext
|
||||
|
||||
7. Return LoginResponse
|
||||
└─ {Token, RefreshToken, User, ExpiresIn}
|
||||
```
|
||||
|
||||
### Thread Safety
|
||||
|
||||
**Mutex Protection:** All access to `oauth2Providers` map is protected with `sync.RWMutex`
|
||||
|
||||
```go
|
||||
type DatabaseAuthenticator struct {
|
||||
oauth2Providers map[string]*OAuth2Provider
|
||||
oauth2ProvidersMutex sync.RWMutex // Thread-safe access
|
||||
}
|
||||
|
||||
// Read operations use RLock
|
||||
func (a *DatabaseAuthenticator) getOAuth2Provider(name string) {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
// ... access map
|
||||
}
|
||||
|
||||
// Write operations use Lock
|
||||
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) {
|
||||
a.oauth2ProvidersMutex.Lock()
|
||||
defer a.oauth2ProvidersMutex.Unlock()
|
||||
// ... modify map
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. Usage Examples
|
||||
|
||||
### Single Provider (Google)
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create Google OAuth2 authenticator
|
||||
auth := security.NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Token refresh endpoint
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh token (provider name defaults to "google")
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, "google")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
})
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
http.ListenAndServe(":8080", router)
|
||||
}
|
||||
```
|
||||
|
||||
### Multi-Provider Setup
|
||||
|
||||
```go
|
||||
// Single authenticator with multiple OAuth2 providers
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
|
||||
// Refresh endpoint with provider selection
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google" or "github"
|
||||
}
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Refresh with specific provider
|
||||
loginResp, err := auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
```
|
||||
|
||||
### Client-Side Usage
|
||||
|
||||
```javascript
|
||||
// JavaScript client example
|
||||
async function refreshAccessToken() {
|
||||
const refreshToken = localStorage.getItem('refresh_token');
|
||||
const provider = localStorage.getItem('auth_provider'); // "google", "github", etc.
|
||||
|
||||
const response = await fetch('/auth/refresh', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
refresh_token: refreshToken,
|
||||
provider: provider
|
||||
})
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = await response.json();
|
||||
|
||||
// Store new tokens
|
||||
localStorage.setItem('access_token', data.token);
|
||||
localStorage.setItem('refresh_token', data.refresh_token);
|
||||
|
||||
console.log('Token refreshed successfully');
|
||||
return data.token;
|
||||
} else {
|
||||
// Refresh failed - redirect to login
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}
|
||||
|
||||
// Automatically refresh token when API returns 401
|
||||
async function apiCall(endpoint) {
|
||||
let response = await fetch(endpoint, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('access_token')
|
||||
}
|
||||
});
|
||||
|
||||
if (response.status === 401) {
|
||||
// Token expired - try refresh
|
||||
const newToken = await refreshAccessToken();
|
||||
|
||||
// Retry with new token
|
||||
response = await fetch(endpoint, {
|
||||
headers: {
|
||||
'Authorization': 'Bearer ' + newToken
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. API Reference
|
||||
|
||||
### DatabaseAuthenticator Methods
|
||||
|
||||
| Method | Signature | Description |
|
||||
|--------|-----------|-------------|
|
||||
| `OAuth2RefreshToken` | `(ctx, refreshToken, provider) (*LoginResponse, error)` | Refreshes expired OAuth2 access token |
|
||||
| `WithOAuth2` | `(cfg OAuth2Config) *DatabaseAuthenticator` | Adds OAuth2 provider (chainable) |
|
||||
| `OAuth2GetAuthURL` | `(provider, state) (string, error)` | Gets authorization URL |
|
||||
| `OAuth2HandleCallback` | `(ctx, provider, code, state) (*LoginResponse, error)` | Handles OAuth2 callback |
|
||||
| `OAuth2GenerateState` | `() (string, error)` | Generates CSRF state token |
|
||||
| `OAuth2GetProviders` | `() []string` | Lists configured providers |
|
||||
|
||||
### LoginResponse Structure
|
||||
|
||||
```go
|
||||
type LoginResponse struct {
|
||||
Token string // New session token
|
||||
RefreshToken string // New refresh token (may be same as input)
|
||||
User *UserContext // User information
|
||||
ExpiresIn int64 // Seconds until expiration
|
||||
}
|
||||
|
||||
type UserContext struct {
|
||||
UserID int // Database user ID
|
||||
UserName string // Username
|
||||
Email string // Email address
|
||||
UserLevel int // Permission level
|
||||
SessionID string // Session token
|
||||
RemoteID string // OAuth2 provider user ID
|
||||
Roles []string // User roles
|
||||
Claims map[string]any // Additional claims
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Important Notes
|
||||
|
||||
### Provider Configuration
|
||||
|
||||
**For Google:** Add `access_type=offline` to get refresh token on first login:
|
||||
|
||||
```go
|
||||
auth := security.NewGoogleAuthenticator(clientID, clientSecret, redirectURL, db)
|
||||
// When generating auth URL, add access_type parameter
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
authURL += "&access_type=offline&prompt=consent"
|
||||
```
|
||||
|
||||
**For GitHub:** Refresh tokens are not always provided. Check provider documentation.
|
||||
|
||||
### Token Storage
|
||||
|
||||
- Store refresh tokens securely on client (localStorage, secure cookie, etc.)
|
||||
- Never log refresh tokens
|
||||
- Refresh tokens are long-lived (days/months depending on provider)
|
||||
- Access tokens are short-lived (minutes/hours)
|
||||
|
||||
### Error Handling
|
||||
|
||||
Common errors:
|
||||
- `"invalid or expired refresh token"` - Token expired or revoked
|
||||
- `"OAuth2 provider 'xxx' not found"` - Provider not configured
|
||||
- `"failed to refresh token with provider"` - Provider rejected refresh request
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **Always use HTTPS** for token transmission
|
||||
2. **Store refresh tokens securely** on client
|
||||
3. **Set appropriate cookie flags**: `HttpOnly`, `Secure`, `SameSite`
|
||||
4. **Implement token rotation** - issue new refresh token on each refresh
|
||||
5. **Revoke old tokens** after successful refresh
|
||||
6. **Rate limit** refresh endpoints
|
||||
7. **Log refresh attempts** for audit trail
|
||||
|
||||
---
|
||||
|
||||
## 6. Testing
|
||||
|
||||
### Manual Test Flow
|
||||
|
||||
1. **Initial Login:**
|
||||
```bash
|
||||
curl http://localhost:8080/auth/google/login
|
||||
# Follow redirect to Google
|
||||
# Returns to callback with LoginResponse containing refresh_token
|
||||
```
|
||||
|
||||
2. **Wait for Token Expiry (or manually expire in DB)**
|
||||
|
||||
3. **Refresh Token:**
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/auth/refresh \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"refresh_token": "ya29.a0AfH6SMB...",
|
||||
"provider": "google"
|
||||
}'
|
||||
|
||||
# Response:
|
||||
{
|
||||
"token": "sess_abc123...",
|
||||
"refresh_token": "ya29.a0AfH6SMB...",
|
||||
"user": {
|
||||
"user_id": 1,
|
||||
"user_name": "john_doe",
|
||||
"email": "john@example.com",
|
||||
"session_id": "sess_abc123..."
|
||||
},
|
||||
"expires_in": 3600
|
||||
}
|
||||
```
|
||||
|
||||
4. **Use New Token:**
|
||||
```bash
|
||||
curl http://localhost:8080/api/protected \
|
||||
-H "Authorization: Bearer sess_abc123..."
|
||||
```
|
||||
|
||||
### Database Verification
|
||||
|
||||
```sql
|
||||
-- Check session with refresh token
|
||||
SELECT session_token, user_id, expires_at, refresh_token, auth_provider
|
||||
FROM user_sessions
|
||||
WHERE refresh_token = 'ya29.a0AfH6SMB...';
|
||||
|
||||
-- Verify token was updated after refresh
|
||||
SELECT session_token, access_token, refresh_token,
|
||||
expires_at, last_activity_at
|
||||
FROM user_sessions
|
||||
WHERE user_id = 1
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Troubleshooting
|
||||
|
||||
### "Refresh token not found or expired"
|
||||
|
||||
**Cause:** Refresh token doesn't exist in database or session expired
|
||||
|
||||
**Solution:**
|
||||
- Check if initial OAuth2 login stored refresh token
|
||||
- Verify provider returns refresh token (some require `access_type=offline`)
|
||||
- Check session hasn't been deleted from database
|
||||
|
||||
### "Failed to refresh token with provider"
|
||||
|
||||
**Cause:** OAuth2 provider rejected the refresh request
|
||||
|
||||
**Possible reasons:**
|
||||
- Refresh token was revoked by user
|
||||
- OAuth2 app credentials changed
|
||||
- Network connectivity issues
|
||||
- Provider rate limiting
|
||||
|
||||
**Solution:**
|
||||
- Re-authenticate user (full OAuth2 flow)
|
||||
- Check provider dashboard for app status
|
||||
- Verify client credentials are correct
|
||||
|
||||
### "OAuth2 provider 'xxx' not found"
|
||||
|
||||
**Cause:** Provider not registered with `WithOAuth2()`
|
||||
|
||||
**Solution:**
|
||||
```go
|
||||
// Make sure provider is configured
|
||||
auth := security.NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(security.OAuth2Config{
|
||||
ProviderName: "google", // This name must match refresh call
|
||||
// ... other config
|
||||
})
|
||||
|
||||
// Then use same name in refresh
|
||||
auth.OAuth2RefreshToken(ctx, token, "google") // Must match ProviderName
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Complete Working Example
|
||||
|
||||
See `pkg/security/oauth2_examples.go:250` for full working example with token refresh.
|
||||
|
||||
---
|
||||
|
||||
## Summary
|
||||
|
||||
OAuth2 refresh token functionality is **production-ready** with:
|
||||
|
||||
- ✅ Complete database schema with stored procedures
|
||||
- ✅ Thread-safe Go implementation with mutex protection
|
||||
- ✅ Multi-provider support (Google, GitHub, Microsoft, Facebook, custom)
|
||||
- ✅ Comprehensive error handling
|
||||
- ✅ Working code examples
|
||||
- ✅ Full API documentation
|
||||
- ✅ Security best practices implemented
|
||||
|
||||
**No additional implementation needed - feature is complete and functional.**
|
||||
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
208
pkg/security/PASSKEY_QUICK_REFERENCE.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# Passkey Authentication Quick Reference
|
||||
|
||||
## Overview
|
||||
Passkey authentication (WebAuthn/FIDO2) is now integrated into the DatabaseAuthenticator. This provides passwordless authentication using biometrics, security keys, or device credentials.
|
||||
|
||||
## Setup
|
||||
|
||||
### Database Schema
|
||||
Run the passkey SQL schema (in database_schema.sql):
|
||||
- Creates `user_passkey_credentials` table
|
||||
- Adds stored procedures for passkey operations
|
||||
|
||||
### Go Code
|
||||
```go
|
||||
// Create passkey provider
|
||||
passkeyProvider := security.NewDatabasePasskeyProvider(db,
|
||||
security.DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com",
|
||||
RPName: "Example App",
|
||||
RPOrigin: "https://example.com",
|
||||
Timeout: 60000,
|
||||
})
|
||||
|
||||
// Create authenticator with passkey support
|
||||
auth := security.NewDatabaseAuthenticatorWithOptions(db,
|
||||
security.DatabaseAuthenticatorOptions{
|
||||
PasskeyProvider: passkeyProvider,
|
||||
})
|
||||
|
||||
// Or add passkey to existing authenticator
|
||||
auth = security.NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
|
||||
```
|
||||
|
||||
## Registration Flow
|
||||
|
||||
### Backend - Step 1: Begin Registration
|
||||
```go
|
||||
options, err := auth.BeginPasskeyRegistration(ctx,
|
||||
security.PasskeyBeginRegistrationRequest{
|
||||
UserID: 1,
|
||||
Username: "alice",
|
||||
DisplayName: "Alice Smith",
|
||||
})
|
||||
// Send options to client as JSON
|
||||
```
|
||||
|
||||
### Frontend - Step 2: Create Credential
|
||||
```javascript
|
||||
// Convert options from server
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
options.user.id = base64ToArrayBuffer(options.user.id);
|
||||
|
||||
// Create credential
|
||||
const credential = await navigator.credentials.create({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Send credential back to server
|
||||
```
|
||||
|
||||
### Backend - Step 3: Complete Registration
|
||||
```go
|
||||
credential, err := auth.CompletePasskeyRegistration(ctx,
|
||||
security.PasskeyRegisterRequest{
|
||||
UserID: 1,
|
||||
Response: clientResponse,
|
||||
ExpectedChallenge: storedChallenge,
|
||||
CredentialName: "My iPhone",
|
||||
})
|
||||
```
|
||||
|
||||
## Authentication Flow
|
||||
|
||||
### Backend - Step 1: Begin Authentication
|
||||
```go
|
||||
options, err := auth.BeginPasskeyAuthentication(ctx,
|
||||
security.PasskeyBeginAuthenticationRequest{
|
||||
Username: "alice", // Optional for resident key
|
||||
})
|
||||
// Send options to client as JSON
|
||||
```
|
||||
|
||||
### Frontend - Step 2: Get Credential
|
||||
```javascript
|
||||
// Convert options from server
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
|
||||
// Get credential
|
||||
const credential = await navigator.credentials.get({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Send assertion back to server
|
||||
```
|
||||
|
||||
### Backend - Step 3: Complete Authentication
|
||||
```go
|
||||
loginResponse, err := auth.LoginWithPasskey(ctx,
|
||||
security.PasskeyLoginRequest{
|
||||
Response: clientAssertion,
|
||||
ExpectedChallenge: storedChallenge,
|
||||
Claims: map[string]any{
|
||||
"ip_address": "192.168.1.1",
|
||||
"user_agent": "Mozilla/5.0...",
|
||||
},
|
||||
})
|
||||
// Returns session token and user info
|
||||
```
|
||||
|
||||
## Credential Management
|
||||
|
||||
### List Credentials
|
||||
```go
|
||||
credentials, err := auth.GetPasskeyCredentials(ctx, userID)
|
||||
```
|
||||
|
||||
### Update Credential Name
|
||||
```go
|
||||
err := auth.UpdatePasskeyCredentialName(ctx, userID, credentialID, "New Name")
|
||||
```
|
||||
|
||||
### Delete Credential
|
||||
```go
|
||||
err := auth.DeletePasskeyCredential(ctx, userID, credentialID)
|
||||
```
|
||||
|
||||
## HTTP Endpoints Example
|
||||
|
||||
### POST /api/passkey/register/begin
|
||||
Request: `{user_id, username, display_name}`
|
||||
Response: PasskeyRegistrationOptions
|
||||
|
||||
### POST /api/passkey/register/complete
|
||||
Request: `{user_id, response, credential_name}`
|
||||
Response: PasskeyCredential
|
||||
|
||||
### POST /api/passkey/login/begin
|
||||
Request: `{username}` (optional)
|
||||
Response: PasskeyAuthenticationOptions
|
||||
|
||||
### POST /api/passkey/login/complete
|
||||
Request: `{response}`
|
||||
Response: LoginResponse with session token
|
||||
|
||||
### GET /api/passkey/credentials
|
||||
Response: Array of PasskeyCredential
|
||||
|
||||
### DELETE /api/passkey/credentials/{id}
|
||||
Request: `{credential_id}`
|
||||
Response: 204 No Content
|
||||
|
||||
## Database Stored Procedures
|
||||
|
||||
- `resolvespec_passkey_store_credential` - Store new credential
|
||||
- `resolvespec_passkey_get_credential` - Get credential by ID
|
||||
- `resolvespec_passkey_get_user_credentials` - Get all user credentials
|
||||
- `resolvespec_passkey_update_counter` - Update sign counter (clone detection)
|
||||
- `resolvespec_passkey_delete_credential` - Delete credential
|
||||
- `resolvespec_passkey_update_name` - Update credential name
|
||||
- `resolvespec_passkey_get_credentials_by_username` - Get credentials for login
|
||||
|
||||
## Security Features
|
||||
|
||||
- **Clone Detection**: Sign counter validation detects credential cloning
|
||||
- **Attestation Support**: Stores attestation type (none, indirect, direct)
|
||||
- **Transport Options**: Tracks authenticator transports (usb, nfc, ble, internal)
|
||||
- **Backup State**: Tracks if credential is backed up/synced
|
||||
- **User Verification**: Supports preferred/required user verification
|
||||
|
||||
## Important Notes
|
||||
|
||||
1. **WebAuthn Library**: Current implementation is simplified. For production, use a proper WebAuthn library like `github.com/go-webauthn/webauthn` for full verification.
|
||||
|
||||
2. **Challenge Storage**: Store challenges securely in session/cache. Never expose challenges to client beyond initial request.
|
||||
|
||||
3. **HTTPS Required**: Passkeys only work over HTTPS (except localhost).
|
||||
|
||||
4. **Browser Support**: Check browser compatibility for WebAuthn API.
|
||||
|
||||
5. **Relying Party ID**: Must match your domain exactly.
|
||||
|
||||
## Client-Side Helper Functions
|
||||
|
||||
```javascript
|
||||
function base64ToArrayBuffer(base64) {
|
||||
const binary = atob(base64);
|
||||
const bytes = new Uint8Array(binary.length);
|
||||
for (let i = 0; i < binary.length; i++) {
|
||||
bytes[i] = binary.charCodeAt(i);
|
||||
}
|
||||
return bytes.buffer;
|
||||
}
|
||||
|
||||
function arrayBufferToBase64(buffer) {
|
||||
const bytes = new Uint8Array(buffer);
|
||||
let binary = '';
|
||||
for (let i = 0; i < bytes.length; i++) {
|
||||
binary += String.fromCharCode(bytes[i]);
|
||||
}
|
||||
return btoa(binary);
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests: `go test -v ./pkg/security -run Passkey`
|
||||
|
||||
All passkey functionality includes comprehensive tests using sqlmock.
|
||||
@@ -7,15 +7,16 @@
|
||||
auth := security.NewDatabaseAuthenticator(db) // Session-based (recommended)
|
||||
// OR: auth := security.NewJWTAuthenticator("secret-key", db)
|
||||
// OR: auth := security.NewHeaderAuthenticator()
|
||||
// OR: auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db) // OAuth2
|
||||
|
||||
colSec := security.NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := security.NewDatabaseRowSecurityProvider(db)
|
||||
|
||||
// Step 2: Combine providers
|
||||
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
|
||||
// Step 3: Setup and apply middleware
|
||||
securityList := security.SetupSecurityProvider(handler, provider)
|
||||
securityList, _ := security.SetupSecurityProvider(handler, provider)
|
||||
router.Use(security.NewAuthMiddleware(securityList))
|
||||
router.Use(security.SetSecurityMiddleware(securityList))
|
||||
```
|
||||
@@ -30,6 +31,7 @@ router.Use(security.SetSecurityMiddleware(securityList))
|
||||
```go
|
||||
// DatabaseAuthenticator uses these stored procedures:
|
||||
resolvespec_login(jsonb) // Login with credentials
|
||||
resolvespec_register(jsonb) // Register new user
|
||||
resolvespec_logout(jsonb) // Invalidate session
|
||||
resolvespec_session(text, text) // Validate session token
|
||||
resolvespec_session_update(text, jsonb) // Update activity timestamp
|
||||
@@ -256,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||
// Add to blacklist
|
||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
||||
"token": req.Token,
|
||||
"user_id": req.UserID,
|
||||
}).Error
|
||||
// Invalidate session via stored procedure
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||
@@ -403,11 +402,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext to context)
|
||||
NewOptionalAuthMiddleware → calls provider.Authenticate()
|
||||
↓ (adds UserContext or guest context; never 401)
|
||||
SetSecurityMiddleware → adds SecurityList to context
|
||||
↓
|
||||
Handler.Handle()
|
||||
Handler.Handle() → resolves model
|
||||
↓
|
||||
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
|
||||
├─ SecurityDisabled → allow
|
||||
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
└─ UserID == 0 → abort 401
|
||||
↓
|
||||
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
||||
↓
|
||||
@@ -502,10 +506,31 @@ func (p *MyProvider) GetColumnSecurity(ctx context.Context, userID int, schema,
|
||||
|
||||
---
|
||||
|
||||
## Login/Logout Endpoints
|
||||
## Login/Logout/Register Endpoints
|
||||
|
||||
```go
|
||||
func SetupAuthRoutes(router *mux.Router, securityList *security.SecurityList) {
|
||||
// Register
|
||||
router.HandleFunc("/auth/register", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req security.RegisterRequest
|
||||
json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Check if provider supports registration
|
||||
registrable, ok := securityList.Provider().(security.Registrable)
|
||||
if !ok {
|
||||
http.Error(w, "Registration not supported", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := registrable.Register(r.Context(), req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}).Methods("POST")
|
||||
|
||||
// Login
|
||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req security.LoginRequest
|
||||
@@ -670,15 +695,30 @@ http.Handle("/api/protected", authHandler)
|
||||
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||
http.Handle("/home", optionalHandler)
|
||||
|
||||
// Example handler
|
||||
func myHandler(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := security.GetUserContext(r.Context())
|
||||
if userCtx.UserID == 0 {
|
||||
// Guest user
|
||||
} else {
|
||||
// Authenticated user
|
||||
}
|
||||
}
|
||||
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
```go
|
||||
// Register model with rules (pkg/modelregistry)
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // skip all auth when true
|
||||
CanPublicRead: true, // unauthenticated reads allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated can update
|
||||
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
|
||||
})
|
||||
|
||||
// CheckModelAuthAllowed used automatically in BeforeHandle hook
|
||||
// No code needed — call RegisterSecurityHooks and it's applied
|
||||
```
|
||||
|
||||
---
|
||||
@@ -707,6 +747,7 @@ meta, ok := security.GetUserMeta(ctx)
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `INTERFACE_GUIDE.md` | **Start here** - Complete implementation guide |
|
||||
| `OAUTH2.md` | **OAuth2 Guide** - Google, GitHub, Microsoft, Facebook, custom providers |
|
||||
| `examples.go` | Working provider implementations to copy |
|
||||
| `setup_example.go` | 6 complete integration examples |
|
||||
| `README.md` | Architecture overview and migration guide |
|
||||
|
||||
@@ -6,11 +6,13 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
|
||||
- ✅ **Interface-Based** - Type-safe providers instead of callbacks
|
||||
- ✅ **Login/Logout Support** - Built-in authentication lifecycle
|
||||
- ✅ **Two-Factor Authentication (2FA)** - Optional TOTP support for enhanced security
|
||||
- ✅ **Composable** - Mix and match different providers
|
||||
- ✅ **No Global State** - Each handler has its own security configuration
|
||||
- ✅ **Testable** - Easy to mock and test
|
||||
- ✅ **Extensible** - Implement custom providers for your needs
|
||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||
|
||||
## Stored Procedure Architecture
|
||||
|
||||
@@ -37,6 +39,12 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
||||
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
||||
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
||||
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
||||
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||
|
||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||
|
||||
@@ -212,6 +220,23 @@ auth := security.NewJWTAuthenticator("secret-key", db)
|
||||
// Note: Requires JWT library installation for token signing/verification
|
||||
```
|
||||
|
||||
**TwoFactorAuthenticator** - Wraps any authenticator with TOTP 2FA:
|
||||
```go
|
||||
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||
|
||||
// Use in-memory provider (for testing)
|
||||
tfaProvider := security.NewMemoryTwoFactorProvider(nil)
|
||||
|
||||
// Or use database provider (for production)
|
||||
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
// Requires: users table with totp fields, user_totp_backup_codes table
|
||||
// Requires: resolvespec_totp_* stored procedures (see totp_database_schema.sql)
|
||||
|
||||
auth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
// Supports: TOTP codes, backup codes, QR code generation
|
||||
// Compatible with Google Authenticator, Microsoft Authenticator, Authy, etc.
|
||||
```
|
||||
|
||||
### Column Security Providers
|
||||
|
||||
**DatabaseColumnSecurityProvider** - Loads rules from database:
|
||||
@@ -334,7 +359,182 @@ func handleRefresh(securityList *security.SecurityList) http.HandlerFunc {
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Two-Factor Authentication (2FA)
|
||||
|
||||
### Overview
|
||||
|
||||
- **Optional per-user** - Enable/disable 2FA individually
|
||||
- **TOTP standard** - Compatible with Google Authenticator, Microsoft Authenticator, Authy, 1Password, etc.
|
||||
- **Configurable** - SHA1/SHA256/SHA512, 6/8 digits, custom time periods
|
||||
- **Backup codes** - One-time recovery codes with secure hashing
|
||||
- **Clock skew** - Handles time differences between client/server
|
||||
|
||||
### Setup
|
||||
|
||||
```go
|
||||
// 1. Wrap existing authenticator with 2FA support
|
||||
baseAuth := security.NewDatabaseAuthenticator(db)
|
||||
tfaProvider := security.NewMemoryTwoFactorProvider(nil) // Use custom DB implementation in production
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
|
||||
// 2. Use as normal authenticator
|
||||
provider := security.NewCompositeSecurityProvider(tfaAuth, colSec, rowSec)
|
||||
securityList := security.NewSecurityList(provider)
|
||||
```
|
||||
|
||||
### Enable 2FA for User
|
||||
|
||||
```go
|
||||
// 1. Initiate 2FA setup
|
||||
secret, err := tfaAuth.Setup2FA(userID, "MyApp", "user@example.com")
|
||||
// Returns: secret.Secret, secret.QRCodeURL, secret.BackupCodes
|
||||
|
||||
// 2. User scans QR code with authenticator app
|
||||
// Display secret.QRCodeURL as QR code image
|
||||
|
||||
// 3. User enters verification code from app
|
||||
code := "123456" // From authenticator app
|
||||
err = tfaAuth.Enable2FA(userID, secret.Secret, code)
|
||||
// 2FA is now enabled for this user
|
||||
|
||||
// 4. Store backup codes securely and show to user once
|
||||
// Display: secret.BackupCodes (10 codes)
|
||||
```
|
||||
|
||||
### Login Flow with 2FA
|
||||
|
||||
```go
|
||||
// 1. User provides credentials
|
||||
req := security.LoginRequest{
|
||||
Username: "user@example.com",
|
||||
Password: "password",
|
||||
}
|
||||
|
||||
resp, err := tfaAuth.Login(ctx, req)
|
||||
|
||||
// 2. Check if 2FA required
|
||||
if resp.Requires2FA {
|
||||
// Prompt user for 2FA code
|
||||
code := getUserInput() // From authenticator app or backup code
|
||||
|
||||
// 3. Login again with 2FA code
|
||||
req.TwoFactorCode = code
|
||||
resp, err = tfaAuth.Login(ctx, req)
|
||||
|
||||
// 4. Success - token is returned
|
||||
token := resp.Token
|
||||
}
|
||||
```
|
||||
|
||||
### Manage 2FA
|
||||
|
||||
```go
|
||||
// Disable 2FA
|
||||
err := tfaAuth.Disable2FA(userID)
|
||||
|
||||
// Regenerate backup codes
|
||||
newCodes, err := tfaAuth.RegenerateBackupCodes(userID, 10)
|
||||
|
||||
// Check status
|
||||
has2FA, err := tfaProvider.Get2FAStatus(userID)
|
||||
```
|
||||
|
||||
### Custom 2FA Storage
|
||||
|
||||
**Option 1: Use DatabaseTwoFactorProvider (Recommended)**
|
||||
|
||||
```go
|
||||
// Uses PostgreSQL stored procedures for all operations
|
||||
db := setupDatabase()
|
||||
|
||||
// Run migrations from totp_database_schema.sql
|
||||
// - Add totp_secret, totp_enabled, totp_enabled_at to users table
|
||||
// - Create user_totp_backup_codes table
|
||||
// - Create resolvespec_totp_* stored procedures
|
||||
|
||||
tfaProvider := security.NewDatabaseTwoFactorProvider(db, nil)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, nil)
|
||||
```
|
||||
|
||||
**Option 2: Implement Custom Provider**
|
||||
|
||||
Implement `TwoFactorAuthProvider` for custom storage:
|
||||
|
||||
```go
|
||||
type DBTwoFactorProvider struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (p *DBTwoFactorProvider) Enable2FA(userID int, secret string, backupCodes []string) error {
|
||||
// Store secret and hashed backup codes in database
|
||||
return p.db.Exec("UPDATE users SET totp_secret = ?, backup_codes = ? WHERE id = ?",
|
||||
secret, hashCodes(backupCodes), userID).Error
|
||||
}
|
||||
|
||||
func (p *DBTwoFactorProvider) Get2FASecret(userID int) (string, error) {
|
||||
var secret string
|
||||
err := p.db.Raw("SELECT totp_secret FROM users WHERE id = ?", userID).Scan(&secret).Error
|
||||
return secret, err
|
||||
}
|
||||
|
||||
// Implement remaining methods: Generate2FASecret, Validate2FACode, Disable2FA,
|
||||
// Get2FAStatus, GenerateBackupCodes, ValidateBackupCode
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
```go
|
||||
config := &security.TwoFactorConfig{
|
||||
Algorithm: "SHA256", // SHA1, SHA256, SHA512
|
||||
Digits: 8, // 6 or 8
|
||||
Period: 30, // Seconds per code
|
||||
SkewWindow: 2, // Accept codes ±2 periods
|
||||
}
|
||||
|
||||
totp := security.NewTOTPGenerator(config)
|
||||
tfaAuth := security.NewTwoFactorAuthenticator(baseAuth, tfaProvider, config)
|
||||
```
|
||||
|
||||
### API Response Structure
|
||||
|
||||
```go
|
||||
// LoginResponse with 2FA
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"`
|
||||
User *UserContext `json:"user"`
|
||||
}
|
||||
|
||||
// TwoFactorSecret for setup
|
||||
type TwoFactorSecret struct {
|
||||
Secret string `json:"secret"` // Base32 encoded
|
||||
QRCodeURL string `json:"qr_code_url"` // otpauth://totp/...
|
||||
BackupCodes []string `json:"backup_codes"` // 10 recovery codes
|
||||
}
|
||||
|
||||
// UserContext includes 2FA status
|
||||
type UserContext struct {
|
||||
UserID int `json:"user_id"`
|
||||
TwoFactorEnabled bool `json:"two_factor_enabled"`
|
||||
// ... other fields
|
||||
}
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
- **Store secrets encrypted** - Never store TOTP secrets in plain text
|
||||
- **Hash backup codes** - Use SHA-256 before storing
|
||||
- **Rate limit** - Limit 2FA verification attempts
|
||||
- **Require password** - Always verify password before disabling 2FA
|
||||
- **Show backup codes once** - Display only during setup/regeneration
|
||||
- **Log 2FA events** - Track enable/disable/failed attempts
|
||||
- **Mark codes as used** - Backup codes are single-use only
|
||||
|
||||
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
http.Error(w, "Refresh not supported", http.StatusNotImplemented)
|
||||
@@ -558,14 +758,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
||||
```
|
||||
HTTP Request
|
||||
↓
|
||||
NewAuthMiddleware (security package)
|
||||
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
|
||||
├─ Calls provider.Authenticate(request)
|
||||
└─ Adds UserContext to context
|
||||
├─ On success: adds authenticated UserContext to context
|
||||
└─ On failure: adds guest UserContext (UserID=0) to context
|
||||
↓
|
||||
SetSecurityMiddleware (security package)
|
||||
└─ Adds SecurityList to context
|
||||
↓
|
||||
Spec Handler (restheadspec/funcspec/resolvespec)
|
||||
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
|
||||
└─ Resolves schema + entity + model from request
|
||||
↓
|
||||
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
|
||||
│ ├─ Loads model rules from context or registry
|
||||
│ ├─ SecurityDisabled → allow
|
||||
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||
│ └─ UserID == 0 → 401 unauthorized
|
||||
└─ On error: aborts with 401
|
||||
↓
|
||||
BeforeRead Hook (registered by spec)
|
||||
├─ Adapts spec's HookContext → SecurityContext
|
||||
@@ -591,7 +802,8 @@ HTTP Response (secured data)
|
||||
```
|
||||
|
||||
**Key Points:**
|
||||
- Security package is spec-agnostic and provides core logic
|
||||
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
|
||||
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
|
||||
- Each spec registers its own hooks that adapt to SecurityContext
|
||||
- Security rules are loaded once and cached for the request
|
||||
- Row security is applied to the query (database level)
|
||||
@@ -692,6 +904,155 @@ securityList := security.NewSecurityList(provider)
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||
```
|
||||
|
||||
## OAuth2 Authorization Server
|
||||
|
||||
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||
|
||||
### Endpoints
|
||||
|
||||
| Method | Path | RFC |
|
||||
|--------|------|-----|
|
||||
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
|
||||
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
|
||||
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
|
||||
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
|
||||
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
|
||||
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
|
||||
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
|
||||
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
|
||||
|
||||
### Config
|
||||
|
||||
```go
|
||||
cfg := security.OAuthServerConfig{
|
||||
Issuer: "https://example.com", // Required — token issuer URL
|
||||
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
|
||||
LoginTitle: "My App Login", // HTML login page title
|
||||
PersistClients: true, // Store clients in DB (multi-instance safe)
|
||||
PersistCodes: true, // Store codes in DB (multi-instance safe)
|
||||
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
|
||||
AccessTokenTTL: time.Hour,
|
||||
AuthCodeTTL: 5 * time.Minute,
|
||||
}
|
||||
```
|
||||
|
||||
| Field | Default | Notes |
|
||||
|-------|---------|-------|
|
||||
| `Issuer` | — | Required |
|
||||
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
|
||||
| `LoginTitle` | `"Login"` | |
|
||||
| `PersistClients` | `false` | Set `true` for multi-instance |
|
||||
| `PersistCodes` | `false` | Set `true` for multi-instance |
|
||||
| `DefaultScopes` | `nil` | |
|
||||
| `AccessTokenTTL` | `1h` | |
|
||||
| `AuthCodeTTL` | `5m` | |
|
||||
|
||||
### Operating Modes
|
||||
|
||||
**Mode 1 — Direct login (username/password form)**
|
||||
|
||||
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
|
||||
|
||||
```go
|
||||
auth := security.NewDatabaseAuthenticator(db)
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
```
|
||||
|
||||
**Mode 2 — External provider federation**
|
||||
|
||||
Pass `nil` as auth and register external providers. The authorize page shows a provider selection UI.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, nil)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
srv.RegisterExternalProvider(githubAuth, "github")
|
||||
```
|
||||
|
||||
**Mode 3 — Both**
|
||||
|
||||
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
|
||||
|
||||
```go
|
||||
srv := security.NewOAuthServer(cfg, auth)
|
||||
srv.RegisterExternalProvider(googleAuth, "google")
|
||||
```
|
||||
|
||||
### Standalone Usage
|
||||
|
||||
```go
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/.well-known/", srv.HTTPHandler())
|
||||
mux.Handle("/oauth/", srv.HTTPHandler())
|
||||
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
|
||||
|
||||
http.ListenAndServe(":8080", mux)
|
||||
```
|
||||
|
||||
### DB Persistence
|
||||
|
||||
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
|
||||
|
||||
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
|
||||
|
||||
#### New DB Types
|
||||
|
||||
```go
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
```
|
||||
|
||||
#### DatabaseAuthenticator OAuth Methods
|
||||
|
||||
```go
|
||||
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
|
||||
auth.OAuthGetClient(ctx, clientID) // retrieve client
|
||||
auth.OAuthSaveCode(ctx, code) // persist authorization code
|
||||
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
|
||||
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
|
||||
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
|
||||
```
|
||||
|
||||
#### SQLNames Fields
|
||||
|
||||
```go
|
||||
type SQLNames struct {
|
||||
// ... existing fields ...
|
||||
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||
}
|
||||
```
|
||||
|
||||
The main changes:
|
||||
1. Security package no longer knows about specific spec types
|
||||
2. Each spec registers its own security hooks
|
||||
@@ -809,15 +1170,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
||||
}
|
||||
```
|
||||
|
||||
## Model-Level Access Control
|
||||
|
||||
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
|
||||
|
||||
```go
|
||||
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||
SecurityDisabled: false, // true = skip all auth checks
|
||||
CanPublicRead: true, // unauthenticated GET allowed
|
||||
CanPublicCreate: false, // requires auth
|
||||
CanPublicUpdate: false, // requires auth
|
||||
CanPublicDelete: false, // requires auth
|
||||
CanUpdate: true, // authenticated users can update
|
||||
CanDelete: false, // authenticated users cannot delete
|
||||
})
|
||||
```
|
||||
|
||||
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
|
||||
1. `SecurityDisabled` → allow all
|
||||
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
|
||||
3. Guest (UserID == 0) → return 401
|
||||
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
|
||||
|
||||
---
|
||||
|
||||
## Middleware and Handler API
|
||||
|
||||
### NewAuthMiddleware
|
||||
Standard middleware that authenticates all requests:
|
||||
Standard middleware that authenticates all requests and returns 401 on failure:
|
||||
|
||||
```go
|
||||
router.Use(security.NewAuthMiddleware(securityList))
|
||||
```
|
||||
|
||||
### NewOptionalAuthMiddleware
|
||||
Middleware for spec routes — always continues; sets guest context on auth failure:
|
||||
|
||||
```go
|
||||
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
|
||||
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
|
||||
```
|
||||
|
||||
Routes can skip authentication using the `SkipAuth` helper:
|
||||
|
||||
```go
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
||||
}
|
||||
|
||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||
// For JWT, logout could involve token blacklisting
|
||||
// Add token to blacklist table
|
||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
||||
// "token": req.Token,
|
||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
||||
// }).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
||||
return applyColumnSecurity(secCtx, securityList)
|
||||
}
|
||||
|
||||
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanUpdate {
|
||||
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
|
||||
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||
func checkModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
return nil // model not registered, allow by default
|
||||
}
|
||||
}
|
||||
if !rules.CanDelete {
|
||||
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
|
||||
// model rules and the current user's authentication state. It is intended for use in
|
||||
// a BeforeHandle hook, fired after model resolution.
|
||||
//
|
||||
// Logic:
|
||||
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
|
||||
// 2. SecurityDisabled → allow.
|
||||
// 3. operation == "read" && CanPublicRead → allow.
|
||||
// 4. operation == "create" && CanPublicCreate → allow.
|
||||
// 5. operation == "update" && CanPublicUpdate → allow.
|
||||
// 6. operation == "delete" && CanPublicDelete → allow.
|
||||
// 7. Guest (UserID == 0) → return "authentication required".
|
||||
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
|
||||
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
|
||||
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||
if !ok {
|
||||
schema := secCtx.GetSchema()
|
||||
entity := secCtx.GetEntity()
|
||||
var err error
|
||||
if schema != "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||
}
|
||||
if err != nil || schema == "" {
|
||||
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||
}
|
||||
if err != nil {
|
||||
// Model not registered - fall through to auth check
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
return nil
|
||||
}
|
||||
if operation == "read" && rules.CanPublicRead {
|
||||
return nil
|
||||
}
|
||||
if operation == "create" && rules.CanPublicCreate {
|
||||
return nil
|
||||
}
|
||||
if operation == "update" && rules.CanPublicUpdate {
|
||||
return nil
|
||||
}
|
||||
if operation == "delete" && rules.CanPublicDelete {
|
||||
return nil
|
||||
}
|
||||
|
||||
userID, _ := secCtx.GetUserID()
|
||||
if userID == 0 {
|
||||
return fmt.Errorf("authentication required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
|
||||
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
|
||||
return checkModelUpdateAllowed(secCtx)
|
||||
}
|
||||
|
||||
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
|
||||
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
|
||||
return checkModelDeleteAllowed(secCtx)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
|
||||
@@ -7,33 +7,48 @@ import (
|
||||
|
||||
// UserContext holds authenticated user information
|
||||
type UserContext struct {
|
||||
UserID int `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserLevel int `json:"user_level"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionRID int64 `json:"session_rid"`
|
||||
RemoteID string `json:"remote_id"`
|
||||
Roles []string `json:"roles"`
|
||||
Email string `json:"email"`
|
||||
Claims map[string]any `json:"claims"`
|
||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||
UserID int `json:"user_id"`
|
||||
UserName string `json:"user_name"`
|
||||
UserLevel int `json:"user_level"`
|
||||
SessionID string `json:"session_id"`
|
||||
SessionRID int64 `json:"session_rid"`
|
||||
RemoteID string `json:"remote_id"`
|
||||
Roles []string `json:"roles"`
|
||||
Email string `json:"email"`
|
||||
Claims map[string]any `json:"claims"`
|
||||
Meta map[string]any `json:"meta"` // Additional metadata that can hold any JSON-serializable values
|
||||
TwoFactorEnabled bool `json:"two_factor_enabled"` // Indicates if 2FA is enabled for this user
|
||||
}
|
||||
|
||||
// LoginRequest contains credentials for login
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Claims map[string]any `json:"claims"` // Additional login data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
TwoFactorCode string `json:"two_factor_code,omitempty"` // TOTP or backup code
|
||||
Claims map[string]any `json:"claims"` // Additional login data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
}
|
||||
|
||||
// RegisterRequest contains information for new user registration
|
||||
type RegisterRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Email string `json:"email"`
|
||||
UserLevel int `json:"user_level"`
|
||||
Roles []string `json:"roles"`
|
||||
Claims map[string]any `json:"claims"` // Additional registration data
|
||||
Meta map[string]any `json:"meta"` // Additional metadata
|
||||
}
|
||||
|
||||
// LoginResponse contains the result of a login attempt
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserContext `json:"user"`
|
||||
ExpiresIn int64 `json:"expires_in"` // Token expiration in seconds
|
||||
Requires2FA bool `json:"requires_2fa"` // True if 2FA code is required
|
||||
TwoFactorSetupData *TwoFactorSecret `json:"two_factor_setup,omitempty"` // Present when setting up 2FA
|
||||
Meta map[string]any `json:"meta"` // Additional metadata to be set on user context
|
||||
}
|
||||
|
||||
// LogoutRequest contains information for logout
|
||||
@@ -55,6 +70,12 @@ type Authenticator interface {
|
||||
Authenticate(r *http.Request) (*UserContext, error)
|
||||
}
|
||||
|
||||
// Registrable allows providers to support user registration
|
||||
type Registrable interface {
|
||||
// Register creates a new user account
|
||||
Register(ctx context.Context, req RegisterRequest) (*LoginResponse, error)
|
||||
}
|
||||
|
||||
// ColumnSecurityProvider handles column-level security (masking/hiding)
|
||||
type ColumnSecurityProvider interface {
|
||||
// GetColumnSecurity loads column security rules for a user and entity
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||
)
|
||||
|
||||
// contextKey is a custom type for context keys to avoid collisions
|
||||
@@ -23,6 +25,7 @@ const (
|
||||
UserMetaKey contextKey = "user_meta"
|
||||
SkipAuthKey contextKey = "skip_auth"
|
||||
OptionalAuthKey contextKey = "optional_auth"
|
||||
ModelRulesKey contextKey = "model_rules"
|
||||
)
|
||||
|
||||
// SkipAuth returns a context with skip auth flag set to true
|
||||
@@ -136,6 +139,31 @@ func NewOptionalAuthHandler(securityList *SecurityList, next http.Handler) http.
|
||||
})
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware creates authentication middleware that always continues.
|
||||
// On auth failure, a guest user context is set instead of returning 401.
|
||||
// Intended for spec routes where auth enforcement is deferred to a BeforeHandle hook
|
||||
// after model resolution.
|
||||
func NewOptionalAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates an authentication middleware with the given security list
|
||||
// This middleware extracts user authentication from the request and adds it to context
|
||||
// Routes can skip authentication by setting SkipAuthKey context value (use SkipAuth helper)
|
||||
@@ -182,6 +210,68 @@ func NewAuthMiddleware(securityList *SecurityList) func(http.Handler) http.Handl
|
||||
}
|
||||
}
|
||||
|
||||
// NewModelAuthMiddleware creates authentication middleware that respects ModelRules for the given model name.
|
||||
// It first checks if ModelRules are set for the model:
|
||||
// - If SecurityDisabled is true, authentication is skipped and a guest context is set.
|
||||
// - Otherwise, all checks from NewAuthMiddleware apply (SkipAuthKey, provider check, OptionalAuthKey, Authenticate).
|
||||
//
|
||||
// If the model is not found in any registry, the middleware falls back to standard NewAuthMiddleware behaviour.
|
||||
func NewModelAuthMiddleware(securityList *SecurityList, modelName string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Check ModelRules first
|
||||
if rules, err := modelregistry.GetModelRulesByName(modelName); err == nil {
|
||||
// Store rules in context for downstream use (e.g., security hooks)
|
||||
r = r.WithContext(context.WithValue(r.Context(), ModelRulesKey, rules))
|
||||
|
||||
if rules.SecurityDisabled {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
isRead := r.Method == http.MethodGet || r.Method == http.MethodHead
|
||||
isUpdate := r.Method == http.MethodPut || r.Method == http.MethodPatch
|
||||
if (isRead && rules.CanPublicRead) || (isUpdate && rules.CanPublicUpdate) {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this route should skip authentication
|
||||
if skip, ok := r.Context().Value(SkipAuthKey).(bool); ok && skip {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
|
||||
// Get the security provider
|
||||
provider := securityList.Provider()
|
||||
if provider == nil {
|
||||
http.Error(w, "Security provider not configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if this route has optional authentication
|
||||
optional, _ := r.Context().Value(OptionalAuthKey).(bool)
|
||||
|
||||
// Try to authenticate
|
||||
userCtx, err := provider.Authenticate(r)
|
||||
if err != nil {
|
||||
if optional {
|
||||
guestCtx := createGuestContext(r)
|
||||
next.ServeHTTP(w, setUserContext(r, guestCtx))
|
||||
return
|
||||
}
|
||||
http.Error(w, "Authentication failed: "+err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, setUserContext(r, userCtx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SetSecurityMiddleware adds security context to requests
|
||||
// This middleware should be applied after AuthMiddleware
|
||||
func SetSecurityMiddleware(securityList *SecurityList) func(http.Handler) http.Handler {
|
||||
@@ -366,6 +456,131 @@ func GetUserMeta(ctx context.Context) (map[string]any, bool) {
|
||||
return meta, ok
|
||||
}
|
||||
|
||||
// SessionCookieOptions configures the session cookie set by SetSessionCookie.
|
||||
// All fields are optional; sensible secure defaults are applied when omitted.
|
||||
type SessionCookieOptions struct {
|
||||
// Name is the cookie name. Defaults to "session_token".
|
||||
Name string
|
||||
// Path is the cookie path. Defaults to "/".
|
||||
Path string
|
||||
// Domain restricts the cookie to a specific domain. Empty means current host.
|
||||
Domain string
|
||||
// Secure sets the Secure flag. Defaults to true.
|
||||
// Set to false only in local development over HTTP.
|
||||
Secure *bool
|
||||
// SameSite sets the SameSite policy. Defaults to http.SameSiteLaxMode.
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) name() string {
|
||||
if o.Name != "" {
|
||||
return o.Name
|
||||
}
|
||||
return "session_token"
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) path() string {
|
||||
if o.Path != "" {
|
||||
return o.Path
|
||||
}
|
||||
return "/"
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) secure() bool {
|
||||
if o.Secure != nil {
|
||||
return *o.Secure
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (o SessionCookieOptions) sameSite() http.SameSite {
|
||||
if o.SameSite != 0 {
|
||||
return o.SameSite
|
||||
}
|
||||
return http.SameSiteLaxMode
|
||||
}
|
||||
|
||||
// SetSessionCookie writes the session_token cookie to the response after a successful login.
|
||||
// Call this immediately after a successful Authenticator.Login() call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// resp, err := auth.Login(r.Context(), req)
|
||||
// if err != nil { ... }
|
||||
// security.SetSessionCookie(w, resp)
|
||||
// json.NewEncoder(w).Encode(resp)
|
||||
func SetSessionCookie(w http.ResponseWriter, loginResp *LoginResponse, opts ...SessionCookieOptions) {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
|
||||
maxAge := 0
|
||||
if loginResp.ExpiresIn > 0 {
|
||||
maxAge = int(loginResp.ExpiresIn)
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: o.name(),
|
||||
Value: loginResp.Token,
|
||||
Path: o.path(),
|
||||
Domain: o.Domain,
|
||||
MaxAge: maxAge,
|
||||
HttpOnly: true,
|
||||
Secure: o.secure(),
|
||||
SameSite: o.sameSite(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetSessionCookie returns the session token value from the request cookie, or empty string if not present.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// token := security.GetSessionCookie(r)
|
||||
func GetSessionCookie(r *http.Request, opts ...SessionCookieOptions) string {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
cookie, err := r.Cookie(o.name())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
// ClearSessionCookie expires the session_token cookie, effectively logging the user out on the browser side.
|
||||
// Call this after a successful Authenticator.Logout() call.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// err := auth.Logout(r.Context(), req)
|
||||
// if err != nil { ... }
|
||||
// security.ClearSessionCookie(w)
|
||||
func ClearSessionCookie(w http.ResponseWriter, opts ...SessionCookieOptions) {
|
||||
var o SessionCookieOptions
|
||||
if len(opts) > 0 {
|
||||
o = opts[0]
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: o.name(),
|
||||
Value: "",
|
||||
Path: o.path(),
|
||||
Domain: o.Domain,
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: o.secure(),
|
||||
SameSite: o.sameSite(),
|
||||
})
|
||||
}
|
||||
|
||||
// GetModelRulesFromContext extracts ModelRules stored by NewModelAuthMiddleware
|
||||
func GetModelRulesFromContext(ctx context.Context) (modelregistry.ModelRules, bool) {
|
||||
rules, ok := ctx.Value(ModelRulesKey).(modelregistry.ModelRules)
|
||||
return rules, ok
|
||||
}
|
||||
|
||||
// // Handler adapters for resolvespec/restheadspec compatibility
|
||||
// // These functions allow using NewAuthHandler and NewOptionalAuthHandler with custom handler abstractions
|
||||
|
||||
|
||||
615
pkg/security/oauth2_examples.go
Normal file
615
pkg/security/oauth2_examples.go
Normal file
@@ -0,0 +1,615 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Example: OAuth2 Authentication with Google
|
||||
func ExampleOAuth2Google() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create OAuth2 authenticator for Google
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Login endpoint - redirects to Google
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Callback endpoint - handles Google response
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
// Return user info as JSON
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 Authentication with GitHub
|
||||
func ExampleOAuth2GitHub() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGitHubAuthenticator(
|
||||
"your-github-client-id",
|
||||
"your-github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Custom OAuth2 Provider
|
||||
func ExampleOAuth2Custom() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Custom OAuth2 provider configuration
|
||||
oauth2Auth := NewDatabaseAuthenticator(db).WithOAuth2(OAuth2Config{
|
||||
ClientID: "your-client-id",
|
||||
ClientSecret: "your-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://your-provider.com/oauth/authorize",
|
||||
TokenURL: "https://your-provider.com/oauth/token",
|
||||
UserInfoURL: "https://your-provider.com/oauth/userinfo",
|
||||
ProviderName: "custom-provider",
|
||||
|
||||
// Custom user info parser
|
||||
UserInfoParser: func(userInfo map[string]any) (*UserContext, error) {
|
||||
// Extract custom fields from your provider
|
||||
return &UserContext{
|
||||
UserName: userInfo["username"].(string),
|
||||
Email: userInfo["email"].(string),
|
||||
RemoteID: userInfo["id"].(string),
|
||||
UserLevel: 1,
|
||||
Roles: []string{"user"},
|
||||
Claims: userInfo,
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("custom-provider", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "custom-provider", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Multi-Provider OAuth2 with Security Integration
|
||||
func ExampleOAuth2MultiProvider() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create OAuth2 authenticators for multiple providers
|
||||
googleAuth := NewGoogleAuthenticator(
|
||||
"google-client-id",
|
||||
"google-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
githubAuth := NewGitHubAuthenticator(
|
||||
"github-client-id",
|
||||
"github-client-secret",
|
||||
"http://localhost:8080/auth/github/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create column and row security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google OAuth2 routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := googleAuth.OAuth2GenerateState()
|
||||
authURL, _ := googleAuth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := googleAuth.OAuth2HandleCallback(r.Context(), "google", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// GitHub OAuth2 routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := githubAuth.OAuth2GenerateState()
|
||||
authURL, _ := githubAuth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := githubAuth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Use Google auth for protected routes (or GitHub - both work)
|
||||
provider, _ := NewCompositeSecurityProvider(googleAuth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected route with authentication
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 with Token Refresh
|
||||
func ExampleOAuth2TokenRefresh() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Refresh token endpoint
|
||||
router.HandleFunc("/auth/refresh", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
Provider string `json:"provider"` // "google", "github", etc.
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "Invalid request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Default to google if not specified
|
||||
if req.Provider == "" {
|
||||
req.Provider = "google"
|
||||
}
|
||||
|
||||
// Use OAuth2-specific refresh method
|
||||
loginResp, err := oauth2Auth.OAuth2RefreshToken(r.Context(), req.RefreshToken, req.Provider)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: OAuth2 Logout
|
||||
func ExampleOAuth2Logout() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
router.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("Authorization")
|
||||
if token == "" {
|
||||
cookie, err := r.Cookie("session_token")
|
||||
if err == nil {
|
||||
token = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
// Get user ID from session
|
||||
userCtx, err := oauth2Auth.Authenticate(r)
|
||||
if err == nil {
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: token,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Clear cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("Logged out successfully"))
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
// Example: Complete OAuth2 Integration with Database Setup
|
||||
func ExampleOAuth2Complete() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create tables (run once)
|
||||
setupOAuth2Tables(db)
|
||||
|
||||
// Create OAuth2 authenticator
|
||||
oauth2Auth := NewGoogleAuthenticator(
|
||||
"your-client-id",
|
||||
"your-client-secret",
|
||||
"http://localhost:8080/auth/google/callback",
|
||||
db,
|
||||
)
|
||||
|
||||
// Create security providers
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(oauth2Auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Public routes
|
||||
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("Welcome! <a href='/auth/google/login'>Login with Google</a>"))
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := oauth2Auth.OAuth2GenerateState()
|
||||
authURL, _ := oauth2Auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
|
||||
loginResp, err := oauth2Auth.OAuth2HandleCallback(r.Context(), "github", code, state)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResp.Token,
|
||||
Path: "/",
|
||||
MaxAge: int(loginResp.ExpiresIn),
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/dashboard", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
// Protected routes
|
||||
protectedRouter := router.PathPrefix("/").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/dashboard", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_, _ = fmt.Fprintf(w, "Welcome, %s! Your email: %s", userCtx.UserName, userCtx.Email)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/api/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
protectedRouter.HandleFunc("/auth/logout", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = oauth2Auth.Logout(r.Context(), LogoutRequest{
|
||||
Token: userCtx.SessionID,
|
||||
UserID: userCtx.UserID,
|
||||
})
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
|
||||
func setupOAuth2Tables(db *sql.DB) {
|
||||
// Create tables from database_schema.sql
|
||||
// This is a helper function - in production, use migrations
|
||||
ctx := context.Background()
|
||||
|
||||
// Create users table if not exists
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
username VARCHAR(255) NOT NULL UNIQUE,
|
||||
email VARCHAR(255) NOT NULL UNIQUE,
|
||||
password VARCHAR(255),
|
||||
user_level INTEGER DEFAULT 0,
|
||||
roles VARCHAR(500),
|
||||
is_active BOOLEAN DEFAULT true,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_login_at TIMESTAMP,
|
||||
remote_id VARCHAR(255),
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
|
||||
// Create user_sessions table (used for both regular and OAuth2 sessions)
|
||||
_, _ = db.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_token VARCHAR(500) NOT NULL UNIQUE,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_activity_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
ip_address VARCHAR(45),
|
||||
user_agent TEXT,
|
||||
access_token TEXT,
|
||||
refresh_token TEXT,
|
||||
token_type VARCHAR(50) DEFAULT 'Bearer',
|
||||
auth_provider VARCHAR(50)
|
||||
)
|
||||
`)
|
||||
}
|
||||
|
||||
// Example: All OAuth2 Providers at Once
|
||||
func ExampleOAuth2AllProviders() {
|
||||
db, _ := sql.Open("postgres", "connection-string")
|
||||
|
||||
// Create authenticator with ALL OAuth2 providers
|
||||
auth := NewDatabaseAuthenticator(db).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "google-client-id",
|
||||
ClientSecret: "google-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/google/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "github-client-id",
|
||||
ClientSecret: "github-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/github/callback",
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "microsoft-client-id",
|
||||
ClientSecret: "microsoft-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/microsoft/callback",
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
ProviderName: "microsoft",
|
||||
}).
|
||||
WithOAuth2(OAuth2Config{
|
||||
ClientID: "facebook-client-id",
|
||||
ClientSecret: "facebook-client-secret",
|
||||
RedirectURL: "http://localhost:8080/auth/facebook/callback",
|
||||
Scopes: []string{"email"},
|
||||
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||
ProviderName: "facebook",
|
||||
})
|
||||
|
||||
// Get list of configured providers
|
||||
providers := auth.OAuth2GetProviders()
|
||||
fmt.Printf("Configured OAuth2 providers: %v\n", providers)
|
||||
|
||||
router := mux.NewRouter()
|
||||
|
||||
// Google routes
|
||||
router.HandleFunc("/auth/google/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("google", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/google/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "google", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// GitHub routes
|
||||
router.HandleFunc("/auth/github/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("github", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/github/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "github", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Microsoft routes
|
||||
router.HandleFunc("/auth/microsoft/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("microsoft", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/microsoft/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "microsoft", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Facebook routes
|
||||
router.HandleFunc("/auth/facebook/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
state, _ := auth.OAuth2GenerateState()
|
||||
authURL, _ := auth.OAuth2GetAuthURL("facebook", state)
|
||||
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||
})
|
||||
router.HandleFunc("/auth/facebook/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginResp, err := auth.OAuth2HandleCallback(r.Context(), "facebook", r.URL.Query().Get("code"), r.URL.Query().Get("state"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(loginResp)
|
||||
})
|
||||
|
||||
// Create security list for protected routes
|
||||
colSec := NewDatabaseColumnSecurityProvider(db)
|
||||
rowSec := NewDatabaseRowSecurityProvider(db)
|
||||
provider, _ := NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||
securityList, _ := NewSecurityList(provider)
|
||||
|
||||
// Protected routes work for ALL OAuth2 providers + regular sessions
|
||||
protectedRouter := router.PathPrefix("/api").Subrouter()
|
||||
protectedRouter.Use(NewAuthMiddleware(securityList))
|
||||
protectedRouter.Use(SetSecurityMiddleware(securityList))
|
||||
|
||||
protectedRouter.HandleFunc("/profile", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, _ := GetUserContext(r.Context())
|
||||
_ = json.NewEncoder(w).Encode(userCtx)
|
||||
})
|
||||
|
||||
_ = http.ListenAndServe(":8080", router)
|
||||
}
|
||||
579
pkg/security/oauth2_methods.go
Normal file
579
pkg/security/oauth2_methods.go
Normal file
@@ -0,0 +1,579 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
// OAuth2Config contains configuration for OAuth2 authentication
|
||||
type OAuth2Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURL string
|
||||
Scopes []string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserInfoURL string
|
||||
ProviderName string
|
||||
|
||||
// Optional: Custom user info parser
|
||||
// If not provided, will use standard claims (sub, email, name)
|
||||
UserInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
}
|
||||
|
||||
// OAuth2Provider holds configuration and state for a single OAuth2 provider
|
||||
type OAuth2Provider struct {
|
||||
config *oauth2.Config
|
||||
userInfoURL string
|
||||
userInfoParser func(userInfo map[string]any) (*UserContext, error)
|
||||
providerName string
|
||||
states map[string]time.Time // state -> expiry time
|
||||
statesMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// WithOAuth2 configures OAuth2 support for the DatabaseAuthenticator
|
||||
// Can be called multiple times to add multiple OAuth2 providers
|
||||
// Returns the same DatabaseAuthenticator instance for method chaining
|
||||
func (a *DatabaseAuthenticator) WithOAuth2(cfg OAuth2Config) *DatabaseAuthenticator {
|
||||
if cfg.ProviderName == "" {
|
||||
cfg.ProviderName = "oauth2"
|
||||
}
|
||||
|
||||
if cfg.UserInfoParser == nil {
|
||||
cfg.UserInfoParser = defaultOAuth2UserInfoParser
|
||||
}
|
||||
|
||||
provider := &OAuth2Provider{
|
||||
config: &oauth2.Config{
|
||||
ClientID: cfg.ClientID,
|
||||
ClientSecret: cfg.ClientSecret,
|
||||
RedirectURL: cfg.RedirectURL,
|
||||
Scopes: cfg.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: cfg.AuthURL,
|
||||
TokenURL: cfg.TokenURL,
|
||||
},
|
||||
},
|
||||
userInfoURL: cfg.UserInfoURL,
|
||||
userInfoParser: cfg.UserInfoParser,
|
||||
providerName: cfg.ProviderName,
|
||||
states: make(map[string]time.Time),
|
||||
}
|
||||
|
||||
// Initialize providers map if needed
|
||||
a.oauth2ProvidersMutex.Lock()
|
||||
if a.oauth2Providers == nil {
|
||||
a.oauth2Providers = make(map[string]*OAuth2Provider)
|
||||
}
|
||||
|
||||
// Register provider
|
||||
a.oauth2Providers[cfg.ProviderName] = provider
|
||||
a.oauth2ProvidersMutex.Unlock()
|
||||
|
||||
// Start state cleanup goroutine for this provider
|
||||
go provider.cleanupStates()
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// OAuth2GetAuthURL returns the OAuth2 authorization URL for redirecting users
|
||||
func (a *DatabaseAuthenticator) OAuth2GetAuthURL(providerName, state string) (string, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Store state for validation
|
||||
provider.statesMutex.Lock()
|
||||
provider.states[state] = time.Now().Add(10 * time.Minute)
|
||||
provider.statesMutex.Unlock()
|
||||
|
||||
return provider.config.AuthCodeURL(state), nil
|
||||
}
|
||||
|
||||
// OAuth2GenerateState generates a random state string for CSRF protection
|
||||
func (a *DatabaseAuthenticator) OAuth2GenerateState() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// OAuth2HandleCallback handles the OAuth2 callback and exchanges code for token
|
||||
func (a *DatabaseAuthenticator) OAuth2HandleCallback(ctx context.Context, providerName, code, state string) (*LoginResponse, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate state
|
||||
if !provider.validateState(state) {
|
||||
return nil, fmt.Errorf("invalid state parameter")
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := provider.config.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
|
||||
// Fetch user info
|
||||
client := provider.config.Client(ctx, token)
|
||||
resp, err := client.Get(provider.userInfoURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch user info: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read user info: %w", err)
|
||||
}
|
||||
|
||||
var userInfo map[string]any
|
||||
if err := json.Unmarshal(body, &userInfo); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user info: %w", err)
|
||||
}
|
||||
|
||||
// Parse user info
|
||||
userCtx, err := provider.userInfoParser(userInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
// Get or create user in database
|
||||
userID, err := a.oauth2GetOrCreateUser(ctx, userCtx, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get or create user: %w", err)
|
||||
}
|
||||
userCtx.UserID = userID
|
||||
|
||||
// Create session token
|
||||
sessionToken, err := a.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate session token: %w", err)
|
||||
}
|
||||
|
||||
expiresAt := time.Now().Add(24 * time.Hour)
|
||||
if token.Expiry.After(time.Now()) {
|
||||
expiresAt = token.Expiry
|
||||
}
|
||||
|
||||
// Store session in database
|
||||
err = a.oauth2CreateSession(ctx, sessionToken, userCtx.UserID, token, expiresAt, providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
userCtx.SessionID = sessionToken
|
||||
|
||||
return &LoginResponse{
|
||||
Token: sessionToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
User: userCtx,
|
||||
ExpiresIn: int64(time.Until(expiresAt).Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// OAuth2GetProviders returns list of configured OAuth2 provider names
|
||||
func (a *DatabaseAuthenticator) OAuth2GetProviders() []string {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
providers := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providers = append(providers, name)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// getOAuth2Provider retrieves a registered OAuth2 provider by name
|
||||
func (a *DatabaseAuthenticator) getOAuth2Provider(providerName string) (*OAuth2Provider, error) {
|
||||
a.oauth2ProvidersMutex.RLock()
|
||||
defer a.oauth2ProvidersMutex.RUnlock()
|
||||
|
||||
if a.oauth2Providers == nil {
|
||||
return nil, fmt.Errorf("OAuth2 not configured - call WithOAuth2() first")
|
||||
}
|
||||
|
||||
provider, ok := a.oauth2Providers[providerName]
|
||||
if !ok {
|
||||
// Build provider list without calling OAuth2GetProviders to avoid recursion
|
||||
providerNames := make([]string, 0, len(a.oauth2Providers))
|
||||
for name := range a.oauth2Providers {
|
||||
providerNames = append(providerNames, name)
|
||||
}
|
||||
return nil, fmt.Errorf("OAuth2 provider '%s' not found - available providers: %v", providerName, providerNames)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// oauth2GetOrCreateUser finds or creates a user based on OAuth2 info using stored procedure
|
||||
func (a *DatabaseAuthenticator) oauth2GetOrCreateUser(ctx context.Context, userCtx *UserContext, providerName string) (int, error) {
|
||||
userData := map[string]interface{}{
|
||||
"username": userCtx.UserName,
|
||||
"email": userCtx.Email,
|
||||
"remote_id": userCtx.RemoteID,
|
||||
"user_level": userCtx.UserLevel,
|
||||
"roles": userCtx.Roles,
|
||||
"auth_provider": providerName,
|
||||
}
|
||||
|
||||
userJSON, err := json.Marshal(userData)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to marshal user data: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var userID *int
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_user_id
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthGetOrCreateUser), userJSON).Scan(&success, &errMsg, &userID)
|
||||
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get or create user: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return 0, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return 0, fmt.Errorf("failed to get or create user")
|
||||
}
|
||||
|
||||
if userID == nil {
|
||||
return 0, fmt.Errorf("user ID not returned")
|
||||
}
|
||||
|
||||
return *userID, nil
|
||||
}
|
||||
|
||||
// oauth2CreateSession creates a new OAuth2 session using stored procedure
|
||||
func (a *DatabaseAuthenticator) oauth2CreateSession(ctx context.Context, sessionToken string, userID int, token *oauth2.Token, expiresAt time.Time, providerName string) error {
|
||||
sessionData := map[string]interface{}{
|
||||
"session_token": sessionToken,
|
||||
"user_id": userID,
|
||||
"access_token": token.AccessToken,
|
||||
"refresh_token": token.RefreshToken,
|
||||
"token_type": token.TokenType,
|
||||
"expires_at": expiresAt,
|
||||
"auth_provider": providerName,
|
||||
}
|
||||
|
||||
sessionJSON, err := json.Marshal(sessionData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal session data: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthCreateSession), sessionJSON).Scan(&success, &errMsg)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create session: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to create session")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateState validates state using in-memory storage
|
||||
func (p *OAuth2Provider) validateState(state string) bool {
|
||||
p.statesMutex.Lock()
|
||||
defer p.statesMutex.Unlock()
|
||||
|
||||
expiry, ok := p.states[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Now().After(expiry) {
|
||||
delete(p.states, state)
|
||||
return false
|
||||
}
|
||||
|
||||
delete(p.states, state) // One-time use
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanupStates removes expired states periodically
|
||||
func (p *OAuth2Provider) cleanupStates() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
p.statesMutex.Lock()
|
||||
now := time.Now()
|
||||
for state, expiry := range p.states {
|
||||
if now.After(expiry) {
|
||||
delete(p.states, state)
|
||||
}
|
||||
}
|
||||
p.statesMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// defaultOAuth2UserInfoParser parses standard OAuth2 user info claims
|
||||
func defaultOAuth2UserInfoParser(userInfo map[string]any) (*UserContext, error) {
|
||||
ctx := &UserContext{
|
||||
Claims: userInfo,
|
||||
Roles: []string{"user"},
|
||||
}
|
||||
|
||||
// Extract standard claims
|
||||
if sub, ok := userInfo["sub"].(string); ok {
|
||||
ctx.RemoteID = sub
|
||||
}
|
||||
if email, ok := userInfo["email"].(string); ok {
|
||||
ctx.Email = email
|
||||
// Use email as username if name not available
|
||||
ctx.UserName = strings.Split(email, "@")[0]
|
||||
}
|
||||
if name, ok := userInfo["name"].(string); ok {
|
||||
ctx.UserName = name
|
||||
}
|
||||
if login, ok := userInfo["login"].(string); ok {
|
||||
ctx.UserName = login // GitHub uses "login"
|
||||
}
|
||||
|
||||
if ctx.UserName == "" {
|
||||
return nil, fmt.Errorf("could not extract username from user info")
|
||||
}
|
||||
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
// OAuth2RefreshToken refreshes an expired OAuth2 access token using the refresh token
|
||||
// Takes the refresh token and returns a new LoginResponse with updated tokens
|
||||
func (a *DatabaseAuthenticator) OAuth2RefreshToken(ctx context.Context, refreshToken, providerName string) (*LoginResponse, error) {
|
||||
provider, err := a.getOAuth2Provider(providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get session by refresh token from database
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var sessionData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetRefreshToken), refreshToken).Scan(&success, &errMsg, &sessionData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session by refresh token: %w", err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired refresh token")
|
||||
}
|
||||
|
||||
// Parse session data
|
||||
var session struct {
|
||||
UserID int `json:"user_id"`
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
Expiry time.Time `json:"expiry"`
|
||||
}
|
||||
if err := json.Unmarshal(sessionData, &session); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse session data: %w", err)
|
||||
}
|
||||
|
||||
// Create oauth2.Token from stored data
|
||||
oldToken := &oauth2.Token{
|
||||
AccessToken: session.AccessToken,
|
||||
TokenType: session.TokenType,
|
||||
RefreshToken: refreshToken,
|
||||
Expiry: session.Expiry,
|
||||
}
|
||||
|
||||
// Use OAuth2 provider to refresh the token
|
||||
tokenSource := provider.config.TokenSource(ctx, oldToken)
|
||||
newToken, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token with provider: %w", err)
|
||||
}
|
||||
|
||||
// Generate new session token
|
||||
newSessionToken, err := a.OAuth2GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate new session token: %w", err)
|
||||
}
|
||||
|
||||
// Update session in database with new tokens
|
||||
updateData := map[string]interface{}{
|
||||
"user_id": session.UserID,
|
||||
"old_refresh_token": refreshToken,
|
||||
"new_session_token": newSessionToken,
|
||||
"new_access_token": newToken.AccessToken,
|
||||
"new_refresh_token": newToken.RefreshToken,
|
||||
"expires_at": newToken.Expiry,
|
||||
}
|
||||
|
||||
updateJSON, err := json.Marshal(updateData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal update data: %w", err)
|
||||
}
|
||||
|
||||
var updateSuccess bool
|
||||
var updateErrMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthUpdateRefreshToken), updateJSON).Scan(&updateSuccess, &updateErrMsg)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update session: %w", err)
|
||||
}
|
||||
|
||||
if !updateSuccess {
|
||||
if updateErrMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *updateErrMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to update session")
|
||||
}
|
||||
|
||||
// Get user data
|
||||
var userSuccess bool
|
||||
var userErrMsg *string
|
||||
var userData []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetUser), session.UserID).Scan(&userSuccess, &userErrMsg, &userData)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user data: %w", err)
|
||||
}
|
||||
|
||||
if !userSuccess {
|
||||
if userErrMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *userErrMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get user data")
|
||||
}
|
||||
|
||||
// Parse user context
|
||||
var userCtx UserContext
|
||||
if err := json.Unmarshal(userData, &userCtx); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse user context: %w", err)
|
||||
}
|
||||
|
||||
userCtx.SessionID = newSessionToken
|
||||
|
||||
return &LoginResponse{
|
||||
Token: newSessionToken,
|
||||
RefreshToken: newToken.RefreshToken,
|
||||
User: &userCtx,
|
||||
ExpiresIn: int64(time.Until(newToken.Expiry).Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pre-configured OAuth2 factory methods
|
||||
|
||||
// NewGoogleAuthenticator creates a DatabaseAuthenticator configured for Google OAuth2
|
||||
func NewGoogleAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
ProviderName: "google",
|
||||
})
|
||||
}
|
||||
|
||||
// NewGitHubAuthenticator creates a DatabaseAuthenticator configured for GitHub OAuth2
|
||||
func NewGitHubAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"user:email"},
|
||||
AuthURL: "https://github.com/login/oauth/authorize",
|
||||
TokenURL: "https://github.com/login/oauth/access_token",
|
||||
UserInfoURL: "https://api.github.com/user",
|
||||
ProviderName: "github",
|
||||
})
|
||||
}
|
||||
|
||||
// NewMicrosoftAuthenticator creates a DatabaseAuthenticator configured for Microsoft OAuth2
|
||||
func NewMicrosoftAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"openid", "profile", "email"},
|
||||
AuthURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
TokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
UserInfoURL: "https://graph.microsoft.com/v1.0/me",
|
||||
ProviderName: "microsoft",
|
||||
})
|
||||
}
|
||||
|
||||
// NewFacebookAuthenticator creates a DatabaseAuthenticator configured for Facebook OAuth2
|
||||
func NewFacebookAuthenticator(clientID, clientSecret, redirectURL string, db *sql.DB) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
return auth.WithOAuth2(OAuth2Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"email"},
|
||||
AuthURL: "https://www.facebook.com/v12.0/dialog/oauth",
|
||||
TokenURL: "https://graph.facebook.com/v12.0/oauth/access_token",
|
||||
UserInfoURL: "https://graph.facebook.com/me?fields=id,name,email",
|
||||
ProviderName: "facebook",
|
||||
})
|
||||
}
|
||||
|
||||
// NewMultiProviderAuthenticator creates a DatabaseAuthenticator with all major OAuth2 providers configured
|
||||
func NewMultiProviderAuthenticator(db *sql.DB, configs map[string]OAuth2Config) *DatabaseAuthenticator {
|
||||
auth := NewDatabaseAuthenticator(db)
|
||||
|
||||
//nolint:gocritic // OAuth2Config is copied but kept for API simplicity
|
||||
for _, cfg := range configs {
|
||||
auth.WithOAuth2(cfg)
|
||||
}
|
||||
|
||||
return auth
|
||||
}
|
||||
859
pkg/security/oauth_server.go
Normal file
859
pkg/security/oauth_server.go
Normal file
@@ -0,0 +1,859 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerConfig configures the MCP-standard OAuth2 authorization server.
|
||||
type OAuthServerConfig struct {
|
||||
// Issuer is the public base URL of this server (e.g. "https://api.example.com").
|
||||
// Used in /.well-known/oauth-authorization-server and to build endpoint URLs.
|
||||
Issuer string
|
||||
|
||||
// ProviderCallbackPath is the path on this server that external OAuth2 providers
|
||||
// redirect back to. Defaults to "/oauth/provider/callback".
|
||||
ProviderCallbackPath string
|
||||
|
||||
// LoginTitle is shown on the built-in login form when the server acts as its own
|
||||
// identity provider. Defaults to "MCP Login".
|
||||
LoginTitle string
|
||||
|
||||
// PersistClients stores registered clients in the database when a DatabaseAuthenticator is provided.
|
||||
// Clients registered during a session survive server restarts.
|
||||
PersistClients bool
|
||||
|
||||
// PersistCodes stores authorization codes in the database.
|
||||
// Useful for multi-instance deployments. Defaults to in-memory.
|
||||
PersistCodes bool
|
||||
|
||||
// DefaultScopes lists scopes advertised in server metadata. Defaults to ["openid","profile","email"].
|
||||
DefaultScopes []string
|
||||
|
||||
// AccessTokenTTL is the issued token lifetime. Defaults to 24h.
|
||||
AccessTokenTTL time.Duration
|
||||
|
||||
// AuthCodeTTL is the auth code lifetime. Defaults to 2 minutes.
|
||||
AuthCodeTTL time.Duration
|
||||
}
|
||||
|
||||
// oauthClient is a dynamically registered OAuth2 client (RFC 7591).
|
||||
type oauthClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// pendingAuth tracks an in-progress authorization code exchange.
|
||||
type pendingAuth struct {
|
||||
ClientID string
|
||||
RedirectURI string
|
||||
ClientState string
|
||||
CodeChallenge string
|
||||
CodeChallengeMethod string
|
||||
ProviderName string // empty = password login
|
||||
ExpiresAt time.Time
|
||||
SessionToken string // set after authentication completes
|
||||
Scopes []string // requested scopes
|
||||
}
|
||||
|
||||
// externalProvider pairs a DatabaseAuthenticator with its provider name.
|
||||
type externalProvider struct {
|
||||
auth *DatabaseAuthenticator
|
||||
providerName string
|
||||
}
|
||||
|
||||
// OAuthServer implements the MCP-standard OAuth2 authorization server (OAuth 2.1 + PKCE).
|
||||
//
|
||||
// It can act as both:
|
||||
// - A direct identity provider using DatabaseAuthenticator username/password login
|
||||
// - A federation layer that delegates authentication to external OAuth2 providers
|
||||
// (Google, GitHub, Microsoft, etc.) registered via RegisterExternalProvider
|
||||
//
|
||||
// The server exposes these RFC-compliant endpoints:
|
||||
//
|
||||
// GET /.well-known/oauth-authorization-server RFC 8414 — server metadata discovery
|
||||
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||
// GET /oauth/authorize OAuth 2.1 + PKCE — start authorization
|
||||
// POST /oauth/authorize Direct login form submission
|
||||
// POST /oauth/token Token exchange and refresh
|
||||
// POST /oauth/revoke RFC 7009 — token revocation
|
||||
// POST /oauth/introspect RFC 7662 — token introspection
|
||||
// GET {ProviderCallbackPath} Internal — external provider callback
|
||||
type OAuthServer struct {
|
||||
cfg OAuthServerConfig
|
||||
auth *DatabaseAuthenticator // nil = only external providers
|
||||
providers []externalProvider
|
||||
|
||||
mu sync.RWMutex
|
||||
clients map[string]*oauthClient
|
||||
pending map[string]*pendingAuth // provider_state → pending (external flow)
|
||||
codes map[string]*pendingAuth // auth_code → pending (post-auth)
|
||||
}
|
||||
|
||||
// NewOAuthServer creates a new MCP OAuth2 authorization server.
|
||||
//
|
||||
// Pass a DatabaseAuthenticator to enable direct username/password login (the server
|
||||
// acts as its own identity provider). Pass nil to use only external providers.
|
||||
// External providers are added separately via RegisterExternalProvider.
|
||||
func NewOAuthServer(cfg OAuthServerConfig, auth *DatabaseAuthenticator) *OAuthServer {
|
||||
if cfg.ProviderCallbackPath == "" {
|
||||
cfg.ProviderCallbackPath = "/oauth/provider/callback"
|
||||
}
|
||||
if cfg.LoginTitle == "" {
|
||||
cfg.LoginTitle = "Sign in"
|
||||
}
|
||||
if len(cfg.DefaultScopes) == 0 {
|
||||
cfg.DefaultScopes = []string{"openid", "profile", "email"}
|
||||
}
|
||||
if cfg.AccessTokenTTL == 0 {
|
||||
cfg.AccessTokenTTL = 24 * time.Hour
|
||||
}
|
||||
if cfg.AuthCodeTTL == 0 {
|
||||
cfg.AuthCodeTTL = 2 * time.Minute
|
||||
}
|
||||
s := &OAuthServer{
|
||||
cfg: cfg,
|
||||
auth: auth,
|
||||
clients: make(map[string]*oauthClient),
|
||||
pending: make(map[string]*pendingAuth),
|
||||
codes: make(map[string]*pendingAuth),
|
||||
}
|
||||
go s.cleanupExpired()
|
||||
return s
|
||||
}
|
||||
|
||||
// RegisterExternalProvider adds an external OAuth2 provider (Google, GitHub, Microsoft, etc.)
|
||||
// that handles user authentication via redirect. The DatabaseAuthenticator must have been
|
||||
// configured with WithOAuth2(providerName, ...) before calling this.
|
||||
// Multiple providers can be registered; the first is used as the default.
|
||||
func (s *OAuthServer) RegisterExternalProvider(auth *DatabaseAuthenticator, providerName string) {
|
||||
s.providers = append(s.providers, externalProvider{auth: auth, providerName: providerName})
|
||||
}
|
||||
|
||||
// ProviderCallbackPath returns the configured path for external provider callbacks.
|
||||
func (s *OAuthServer) ProviderCallbackPath() string {
|
||||
return s.cfg.ProviderCallbackPath
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.Handler that serves all RFC-required OAuth2 endpoints.
|
||||
// Mount it at the root of your HTTP server alongside the MCP transport.
|
||||
//
|
||||
// mux := http.NewServeMux()
|
||||
// mux.Handle("/", oauthServer.HTTPHandler())
|
||||
// mux.Handle("/mcp/", mcpTransport)
|
||||
func (s *OAuthServer) HTTPHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/.well-known/oauth-authorization-server", s.metadataHandler)
|
||||
mux.HandleFunc("/oauth/register", s.registerHandler)
|
||||
mux.HandleFunc("/oauth/authorize", s.authorizeHandler)
|
||||
mux.HandleFunc("/oauth/token", s.tokenHandler)
|
||||
mux.HandleFunc("/oauth/revoke", s.revokeHandler)
|
||||
mux.HandleFunc("/oauth/introspect", s.introspectHandler)
|
||||
mux.HandleFunc(s.cfg.ProviderCallbackPath, s.providerCallbackHandler)
|
||||
return mux
|
||||
}
|
||||
|
||||
// cleanupExpired removes stale pending auths and codes every 5 minutes.
|
||||
func (s *OAuthServer) cleanupExpired() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.mu.Lock()
|
||||
for k, p := range s.pending {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.pending, k)
|
||||
}
|
||||
}
|
||||
for k, p := range s.codes {
|
||||
if now.After(p.ExpiresAt) {
|
||||
delete(s.codes, k)
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 8414 — Server metadata
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) metadataHandler(w http.ResponseWriter, r *http.Request) {
|
||||
issuer := s.cfg.Issuer
|
||||
meta := map[string]interface{}{
|
||||
"issuer": issuer,
|
||||
"authorization_endpoint": issuer + "/oauth/authorize",
|
||||
"token_endpoint": issuer + "/oauth/token",
|
||||
"registration_endpoint": issuer + "/oauth/register",
|
||||
"revocation_endpoint": issuer + "/oauth/revoke",
|
||||
"introspection_endpoint": issuer + "/oauth/introspect",
|
||||
"scopes_supported": s.cfg.DefaultScopes,
|
||||
"response_types_supported": []string{"code"},
|
||||
"grant_types_supported": []string{"authorization_code", "refresh_token"},
|
||||
"code_challenge_methods_supported": []string{"S256"},
|
||||
"token_endpoint_auth_methods_supported": []string{"none"},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(meta) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7591 — Dynamic client registration
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) registerHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
var req struct {
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "malformed JSON", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if len(req.RedirectURIs) == 0 {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uris required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
grantTypes := req.GrantTypes
|
||||
if len(grantTypes) == 0 {
|
||||
grantTypes = []string{"authorization_code"}
|
||||
}
|
||||
allowedScopes := req.AllowedScopes
|
||||
if len(allowedScopes) == 0 {
|
||||
allowedScopes = s.cfg.DefaultScopes
|
||||
}
|
||||
clientID, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
client := &oauthClient{
|
||||
ClientID: clientID,
|
||||
RedirectURIs: req.RedirectURIs,
|
||||
ClientName: req.ClientName,
|
||||
GrantTypes: grantTypes,
|
||||
AllowedScopes: allowedScopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistClients && s.auth != nil {
|
||||
dbClient := &OAuthServerClient{
|
||||
ClientID: client.ClientID,
|
||||
RedirectURIs: client.RedirectURIs,
|
||||
ClientName: client.ClientName,
|
||||
GrantTypes: client.GrantTypes,
|
||||
AllowedScopes: client.AllowedScopes,
|
||||
}
|
||||
if _, err := s.auth.OAuthRegisterClient(r.Context(), dbClient); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = client
|
||||
s.mu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(client) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Authorization endpoint — GET + POST /oauth/authorize
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) authorizeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
s.authorizeGet(w, r)
|
||||
case http.MethodPost:
|
||||
s.authorizePost(w, r)
|
||||
default:
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// authorizeGet validates the request and either:
|
||||
// - Redirects to an external provider (if providers are registered)
|
||||
// - Renders a login form (if the server is its own identity provider)
|
||||
func (s *OAuthServer) authorizeGet(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
clientID := q.Get("client_id")
|
||||
redirectURI := q.Get("redirect_uri")
|
||||
clientState := q.Get("state")
|
||||
codeChallenge := q.Get("code_challenge")
|
||||
codeChallengeMethod := q.Get("code_challenge_method")
|
||||
providerName := q.Get("provider")
|
||||
scopeStr := q.Get("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
if q.Get("response_type") != "code" {
|
||||
writeOAuthError(w, "unsupported_response_type", "only 'code' is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallenge == "" {
|
||||
writeOAuthError(w, "invalid_request", "code_challenge required (PKCE S256)", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if codeChallengeMethod != "" && codeChallengeMethod != "S256" {
|
||||
writeOAuthError(w, "invalid_request", "only S256 code_challenge_method is supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok {
|
||||
writeOAuthError(w, "invalid_client", "unknown client_id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
writeOAuthError(w, "invalid_request", "redirect_uri not registered", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// External provider path
|
||||
if len(s.providers) > 0 {
|
||||
s.redirectToExternalProvider(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName, scopes)
|
||||
return
|
||||
}
|
||||
|
||||
// Direct login form path (server is its own identity provider)
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "")
|
||||
}
|
||||
|
||||
// authorizePost handles login form submission for the direct login flow.
|
||||
func (s *OAuthServer) authorizePost(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
clientID := r.FormValue("client_id")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientState := r.FormValue("client_state")
|
||||
codeChallenge := r.FormValue("code_challenge")
|
||||
codeChallengeMethod := r.FormValue("code_challenge_method")
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
scopeStr := r.FormValue("scope")
|
||||
var scopes []string
|
||||
if scopeStr != "" {
|
||||
scopes = strings.Fields(scopeStr)
|
||||
}
|
||||
|
||||
client, ok := s.lookupOrFetchClient(r.Context(), clientID)
|
||||
if !ok || !oauthSliceContains(client.RedirectURIs, redirectURI) {
|
||||
http.Error(w, "invalid client or redirect_uri", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if s.auth == nil {
|
||||
http.Error(w, "no authentication provider configured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := s.auth.Login(r.Context(), LoginRequest{
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
if err != nil {
|
||||
s.renderLoginForm(w, r, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scopeStr, "Invalid username or password")
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, "", scopes)
|
||||
}
|
||||
|
||||
// redirectToExternalProvider stores the pending auth and redirects to the configured provider.
|
||||
func (s *OAuthServer) redirectToExternalProvider(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
var provider *externalProvider
|
||||
if providerName != "" {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == providerName {
|
||||
provider = &s.providers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", providerName), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
provider = &s.providers[0]
|
||||
}
|
||||
|
||||
providerState, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: provider.providerName,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
Scopes: scopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.pending[providerState] = pending
|
||||
s.mu.Unlock()
|
||||
|
||||
authURL, err := provider.auth.OAuth2GetAuthURL(provider.providerName, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// External provider callback — GET {ProviderCallbackPath}
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) providerCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
providerState := r.URL.Query().Get("state")
|
||||
|
||||
if code == "" {
|
||||
http.Error(w, "missing code", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
pending, ok := s.pending[providerState]
|
||||
if ok {
|
||||
delete(s.pending, providerState)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
http.Error(w, "invalid or expired state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
provider := s.providerByName(pending.ProviderName)
|
||||
if provider == nil {
|
||||
http.Error(w, fmt.Sprintf("provider %q not found", pending.ProviderName), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
loginResp, err := provider.auth.OAuth2HandleCallback(r.Context(), pending.ProviderName, code, providerState)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
s.issueCodeAndRedirect(w, r, loginResp.Token,
|
||||
pending.ClientID, pending.RedirectURI, pending.ClientState,
|
||||
pending.CodeChallenge, pending.CodeChallengeMethod, pending.ProviderName, pending.Scopes)
|
||||
}
|
||||
|
||||
// issueCodeAndRedirect generates a short-lived auth code and redirects to the MCP client.
|
||||
func (s *OAuthServer) issueCodeAndRedirect(w http.ResponseWriter, r *http.Request, sessionToken, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, providerName string, scopes []string) {
|
||||
authCode, err := randomOAuthToken()
|
||||
if err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
pending := &pendingAuth{
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
ProviderName: providerName,
|
||||
SessionToken: sessionToken,
|
||||
ExpiresAt: time.Now().Add(s.cfg.AuthCodeTTL),
|
||||
Scopes: scopes,
|
||||
}
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode := &OAuthCode{
|
||||
Code: authCode,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
ClientState: clientState,
|
||||
CodeChallenge: codeChallenge,
|
||||
CodeChallengeMethod: codeChallengeMethod,
|
||||
SessionToken: sessionToken,
|
||||
Scopes: scopes,
|
||||
ExpiresAt: pending.ExpiresAt,
|
||||
}
|
||||
if err := s.auth.OAuthSaveCode(r.Context(), oauthCode); err != nil {
|
||||
http.Error(w, "server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
s.codes[authCode] = pending
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
redirectURL, err := url.Parse(redirectURI)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid redirect_uri", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
qp := redirectURL.Query()
|
||||
qp.Set("code", authCode)
|
||||
if clientState != "" {
|
||||
qp.Set("state", clientState)
|
||||
}
|
||||
redirectURL.RawQuery = qp.Encode()
|
||||
http.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Token endpoint — POST /oauth/token
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) tokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
writeOAuthError(w, "invalid_request", "cannot parse form", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch r.FormValue("grant_type") {
|
||||
case "authorization_code":
|
||||
s.handleAuthCodeGrant(w, r)
|
||||
case "refresh_token":
|
||||
s.handleRefreshGrant(w, r)
|
||||
default:
|
||||
writeOAuthError(w, "unsupported_grant_type", "", http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleAuthCodeGrant(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.FormValue("code")
|
||||
redirectURI := r.FormValue("redirect_uri")
|
||||
clientID := r.FormValue("client_id")
|
||||
codeVerifier := r.FormValue("code_verifier")
|
||||
|
||||
if code == "" || codeVerifier == "" {
|
||||
writeOAuthError(w, "invalid_request", "code and code_verifier required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var sessionToken string
|
||||
var scopes []string
|
||||
|
||||
if s.cfg.PersistCodes && s.auth != nil {
|
||||
oauthCode, err := s.auth.OAuthExchangeCode(r.Context(), code)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if oauthCode.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(oauthCode.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = oauthCode.SessionToken
|
||||
scopes = oauthCode.Scopes
|
||||
} else {
|
||||
s.mu.Lock()
|
||||
pending, ok := s.codes[code]
|
||||
if ok {
|
||||
delete(s.codes, code)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !ok || time.Now().After(pending.ExpiresAt) {
|
||||
writeOAuthError(w, "invalid_grant", "code expired or invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.ClientID != clientID {
|
||||
writeOAuthError(w, "invalid_client", "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if pending.RedirectURI != redirectURI {
|
||||
writeOAuthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if !validatePKCESHA256(pending.CodeChallenge, codeVerifier) {
|
||||
writeOAuthError(w, "invalid_grant", "code_verifier invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
sessionToken = pending.SessionToken
|
||||
scopes = pending.Scopes
|
||||
}
|
||||
|
||||
writeOAuthToken(w, sessionToken, "", scopes)
|
||||
}
|
||||
|
||||
func (s *OAuthServer) handleRefreshGrant(w http.ResponseWriter, r *http.Request) {
|
||||
refreshToken := r.FormValue("refresh_token")
|
||||
providerName := r.FormValue("provider")
|
||||
if refreshToken == "" {
|
||||
writeOAuthError(w, "invalid_request", "refresh_token required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Try external providers first, then fall back to DatabaseAuthenticator
|
||||
provider := s.providerByName(providerName)
|
||||
if provider != nil {
|
||||
loginResp, err := provider.auth.OAuth2RefreshToken(r.Context(), refreshToken, providerName)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
loginResp, err := s.auth.RefreshToken(r.Context(), refreshToken)
|
||||
if err != nil {
|
||||
writeOAuthError(w, "invalid_grant", err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
writeOAuthToken(w, loginResp.Token, loginResp.RefreshToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
writeOAuthError(w, "invalid_grant", "no provider available for refresh", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7009 — Token revocation
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) revokeHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
if token == "" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
if s.auth != nil {
|
||||
s.auth.OAuthRevokeToken(r.Context(), token) //nolint:errcheck
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// RFC 7662 — Token introspection
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) introspectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
token := r.FormValue("token")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if token == "" || s.auth == nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
|
||||
info, err := s.auth.OAuthIntrospectToken(r.Context(), token)
|
||||
if err != nil {
|
||||
w.Write([]byte(`{"active":false}`)) //nolint:errcheck
|
||||
return
|
||||
}
|
||||
json.NewEncoder(w).Encode(info) //nolint:errcheck
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Login form (direct identity provider mode)
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
func (s *OAuthServer) renderLoginForm(w http.ResponseWriter, r *http.Request, clientID, redirectURI, clientState, codeChallenge, codeChallengeMethod, scope, errMsg string) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
errHTML := ""
|
||||
if errMsg != "" {
|
||||
errHTML = `<p style="color:red">` + errMsg + `</p>`
|
||||
}
|
||||
fmt.Fprintf(w, loginFormHTML,
|
||||
s.cfg.LoginTitle,
|
||||
s.cfg.LoginTitle,
|
||||
errHTML,
|
||||
clientID,
|
||||
htmlEscape(redirectURI),
|
||||
htmlEscape(clientState),
|
||||
htmlEscape(codeChallenge),
|
||||
htmlEscape(codeChallengeMethod),
|
||||
htmlEscape(scope),
|
||||
)
|
||||
}
|
||||
|
||||
const loginFormHTML = `<!DOCTYPE html>
|
||||
<html><head><meta charset="utf-8"><title>%s</title>
|
||||
<style>body{font-family:sans-serif;display:flex;justify-content:center;align-items:center;min-height:100vh;margin:0;background:#f5f5f5}
|
||||
.card{background:#fff;padding:2rem;border-radius:8px;box-shadow:0 2px 8px rgba(0,0,0,.15);width:320px}
|
||||
h2{margin:0 0 1.5rem;font-size:1.25rem}
|
||||
label{display:block;margin-bottom:.25rem;font-size:.875rem;color:#555}
|
||||
input[type=text],input[type=password]{width:100%%;box-sizing:border-box;padding:.5rem;border:1px solid #ccc;border-radius:4px;margin-bottom:1rem;font-size:1rem}
|
||||
button{width:100%%;padding:.6rem;background:#0070f3;color:#fff;border:none;border-radius:4px;font-size:1rem;cursor:pointer}
|
||||
button:hover{background:#005fd4}.err{color:#d32f2f;margin-bottom:1rem;font-size:.875rem}</style>
|
||||
</head><body><div class="card">
|
||||
<h2>%s</h2>%s
|
||||
<form method="POST" action="/oauth/authorize">
|
||||
<input type="hidden" name="client_id" value="%s">
|
||||
<input type="hidden" name="redirect_uri" value="%s">
|
||||
<input type="hidden" name="client_state" value="%s">
|
||||
<input type="hidden" name="code_challenge" value="%s">
|
||||
<input type="hidden" name="code_challenge_method" value="%s">
|
||||
<input type="hidden" name="scope" value="%s">
|
||||
<label>Username</label><input type="text" name="username" autofocus autocomplete="username">
|
||||
<label>Password</label><input type="password" name="password" autocomplete="current-password">
|
||||
<button type="submit">Sign in</button>
|
||||
</form></div></body></html>`
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// lookupOrFetchClient checks in-memory first, then DB if PersistClients is enabled.
|
||||
func (s *OAuthServer) lookupOrFetchClient(ctx context.Context, clientID string) (*oauthClient, bool) {
|
||||
s.mu.RLock()
|
||||
c, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
|
||||
if !s.cfg.PersistClients || s.auth == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
dbClient, err := s.auth.OAuthGetClient(ctx, clientID)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
c = &oauthClient{
|
||||
ClientID: dbClient.ClientID,
|
||||
RedirectURIs: dbClient.RedirectURIs,
|
||||
ClientName: dbClient.ClientName,
|
||||
GrantTypes: dbClient.GrantTypes,
|
||||
AllowedScopes: dbClient.AllowedScopes,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.clients[clientID] = c
|
||||
s.mu.Unlock()
|
||||
return c, true
|
||||
}
|
||||
|
||||
func (s *OAuthServer) providerByName(name string) *externalProvider {
|
||||
for i := range s.providers {
|
||||
if s.providers[i].providerName == name {
|
||||
return &s.providers[i]
|
||||
}
|
||||
}
|
||||
// If name is empty and only one provider exists, return it
|
||||
if name == "" && len(s.providers) == 1 {
|
||||
return &s.providers[0]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePKCESHA256(challenge, verifier string) bool {
|
||||
h := sha256.Sum256([]byte(verifier))
|
||||
return base64.RawURLEncoding.EncodeToString(h[:]) == challenge
|
||||
}
|
||||
|
||||
func randomOAuthToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func oauthSliceContains(slice []string, s string) bool {
|
||||
for _, v := range slice {
|
||||
if strings.EqualFold(v, s) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func writeOAuthToken(w http.ResponseWriter, accessToken, refreshToken string, scopes []string) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
}
|
||||
if refreshToken != "" {
|
||||
resp["refresh_token"] = refreshToken
|
||||
}
|
||||
if len(scopes) > 0 {
|
||||
resp["scope"] = strings.Join(scopes, " ")
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
w.Header().Set("Pragma", "no-cache")
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func writeOAuthError(w http.ResponseWriter, errCode, description string, status int) {
|
||||
resp := map[string]string{"error": errCode}
|
||||
if description != "" {
|
||||
resp["error_description"] = description
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(resp) //nolint:errcheck
|
||||
}
|
||||
|
||||
func htmlEscape(s string) string {
|
||||
s = strings.ReplaceAll(s, "&", "&")
|
||||
s = strings.ReplaceAll(s, `"`, """)
|
||||
s = strings.ReplaceAll(s, "<", "<")
|
||||
s = strings.ReplaceAll(s, ">", ">")
|
||||
return s
|
||||
}
|
||||
202
pkg/security/oauth_server_db.go
Normal file
202
pkg/security/oauth_server_db.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OAuthServerClient is a persisted RFC 7591 registered OAuth2 client.
|
||||
type OAuthServerClient struct {
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURIs []string `json:"redirect_uris"`
|
||||
ClientName string `json:"client_name,omitempty"`
|
||||
GrantTypes []string `json:"grant_types"`
|
||||
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthCode is a short-lived authorization code.
|
||||
type OAuthCode struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
ClientState string `json:"client_state,omitempty"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
SessionToken string `json:"session_token"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
// OAuthTokenInfo is the RFC 7662 token introspection response.
|
||||
type OAuthTokenInfo struct {
|
||||
Active bool `json:"active"`
|
||||
Sub string `json:"sub,omitempty"`
|
||||
Username string `json:"username,omitempty"`
|
||||
Email string `json:"email,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
Exp int64 `json:"exp,omitempty"`
|
||||
Iat int64 `json:"iat,omitempty"`
|
||||
}
|
||||
|
||||
// OAuthRegisterClient persists an OAuth2 client registration.
|
||||
func (a *DatabaseAuthenticator) OAuthRegisterClient(ctx context.Context, client *OAuthServerClient) (*OAuthServerClient, error) {
|
||||
input, err := json.Marshal(client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal client: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthRegisterClient), input).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to register client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to register client")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse registered client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthGetClient retrieves a registered client by ID.
|
||||
func (a *DatabaseAuthenticator) OAuthGetClient(ctx context.Context, clientID string) (*OAuthServerClient, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthGetClient), clientID).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("client not found")
|
||||
}
|
||||
|
||||
var result OAuthServerClient
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse client: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthSaveCode persists an authorization code.
|
||||
func (a *DatabaseAuthenticator) OAuthSaveCode(ctx context.Context, code *OAuthCode) error {
|
||||
input, err := json.Marshal(code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal code: %w", err)
|
||||
}
|
||||
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err = a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1::jsonb)
|
||||
`, a.sqlNames.OAuthSaveCode), input).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to save code")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// OAuthExchangeCode retrieves and deletes an authorization code (single use).
|
||||
func (a *DatabaseAuthenticator) OAuthExchangeCode(ctx context.Context, code string) (*OAuthCode, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthExchangeCode), code).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("invalid or expired code")
|
||||
}
|
||||
|
||||
var result OAuthCode
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse code data: %w", err)
|
||||
}
|
||||
result.Code = code
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthIntrospectToken validates a token and returns its metadata (RFC 7662).
|
||||
func (a *DatabaseAuthenticator) OAuthIntrospectToken(ctx context.Context, token string) (*OAuthTokenInfo, error) {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
var data []byte
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error, p_data::text
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthIntrospect), token).Scan(&success, &errMsg, &data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to introspect token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return nil, fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return nil, fmt.Errorf("introspection failed")
|
||||
}
|
||||
|
||||
var result OAuthTokenInfo
|
||||
if err := json.Unmarshal(data, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token info: %w", err)
|
||||
}
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// OAuthRevokeToken revokes a token by deleting the session (RFC 7009).
|
||||
func (a *DatabaseAuthenticator) OAuthRevokeToken(ctx context.Context, token string) error {
|
||||
var success bool
|
||||
var errMsg *string
|
||||
|
||||
err := a.db.QueryRowContext(ctx, fmt.Sprintf(`
|
||||
SELECT p_success, p_error
|
||||
FROM %s($1)
|
||||
`, a.sqlNames.OAuthRevoke), token).Scan(&success, &errMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke token: %w", err)
|
||||
}
|
||||
if !success {
|
||||
if errMsg != nil {
|
||||
return fmt.Errorf("%s", *errMsg)
|
||||
}
|
||||
return fmt.Errorf("failed to revoke token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
185
pkg/security/passkey.go
Normal file
185
pkg/security/passkey.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// PasskeyCredential represents a stored WebAuthn/FIDO2 credential
|
||||
type PasskeyCredential struct {
|
||||
ID string `json:"id"`
|
||||
UserID int `json:"user_id"`
|
||||
CredentialID []byte `json:"credential_id"` // Raw credential ID from authenticator
|
||||
PublicKey []byte `json:"public_key"` // COSE public key
|
||||
AttestationType string `json:"attestation_type"` // none, indirect, direct
|
||||
AAGUID []byte `json:"aaguid"` // Authenticator AAGUID
|
||||
SignCount uint32 `json:"sign_count"` // Signature counter
|
||||
CloneWarning bool `json:"clone_warning"` // True if cloning detected
|
||||
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
|
||||
BackupEligible bool `json:"backup_eligible"` // Credential can be backed up
|
||||
BackupState bool `json:"backup_state"` // Credential is currently backed up
|
||||
Name string `json:"name,omitempty"` // User-friendly name
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
LastUsedAt time.Time `json:"last_used_at"`
|
||||
}
|
||||
|
||||
// PasskeyRegistrationOptions contains options for beginning passkey registration
|
||||
type PasskeyRegistrationOptions struct {
|
||||
Challenge []byte `json:"challenge"`
|
||||
RelyingParty PasskeyRelyingParty `json:"rp"`
|
||||
User PasskeyUser `json:"user"`
|
||||
PubKeyCredParams []PasskeyCredentialParam `json:"pubKeyCredParams"`
|
||||
Timeout int64 `json:"timeout,omitempty"` // Milliseconds
|
||||
ExcludeCredentials []PasskeyCredentialDescriptor `json:"excludeCredentials,omitempty"`
|
||||
AuthenticatorSelection *PasskeyAuthenticatorSelection `json:"authenticatorSelection,omitempty"`
|
||||
Attestation string `json:"attestation,omitempty"` // none, indirect, direct, enterprise
|
||||
Extensions map[string]any `json:"extensions,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyAuthenticationOptions contains options for beginning passkey authentication
|
||||
type PasskeyAuthenticationOptions struct {
|
||||
Challenge []byte `json:"challenge"`
|
||||
Timeout int64 `json:"timeout,omitempty"`
|
||||
RelyingPartyID string `json:"rpId,omitempty"`
|
||||
AllowCredentials []PasskeyCredentialDescriptor `json:"allowCredentials,omitempty"`
|
||||
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
|
||||
Extensions map[string]any `json:"extensions,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyRelyingParty identifies the relying party
|
||||
type PasskeyRelyingParty struct {
|
||||
ID string `json:"id"` // Domain (e.g., "example.com")
|
||||
Name string `json:"name"` // Display name
|
||||
}
|
||||
|
||||
// PasskeyUser identifies the user
|
||||
type PasskeyUser struct {
|
||||
ID []byte `json:"id"` // User handle (unique, persistent)
|
||||
Name string `json:"name"` // Username
|
||||
DisplayName string `json:"displayName"` // Display name
|
||||
}
|
||||
|
||||
// PasskeyCredentialParam specifies supported public key algorithm
|
||||
type PasskeyCredentialParam struct {
|
||||
Type string `json:"type"` // "public-key"
|
||||
Alg int `json:"alg"` // COSE algorithm identifier (e.g., -7 for ES256, -257 for RS256)
|
||||
}
|
||||
|
||||
// PasskeyCredentialDescriptor describes a credential
|
||||
type PasskeyCredentialDescriptor struct {
|
||||
Type string `json:"type"` // "public-key"
|
||||
ID []byte `json:"id"` // Credential ID
|
||||
Transports []string `json:"transports,omitempty"` // usb, nfc, ble, internal
|
||||
}
|
||||
|
||||
// PasskeyAuthenticatorSelection specifies authenticator requirements
|
||||
type PasskeyAuthenticatorSelection struct {
|
||||
AuthenticatorAttachment string `json:"authenticatorAttachment,omitempty"` // platform, cross-platform
|
||||
RequireResidentKey bool `json:"requireResidentKey,omitempty"`
|
||||
ResidentKey string `json:"residentKey,omitempty"` // discouraged, preferred, required
|
||||
UserVerification string `json:"userVerification,omitempty"` // required, preferred, discouraged
|
||||
}
|
||||
|
||||
// PasskeyRegistrationResponse contains the client's registration response
|
||||
type PasskeyRegistrationResponse struct {
|
||||
ID string `json:"id"` // Base64URL encoded credential ID
|
||||
RawID []byte `json:"rawId"` // Raw credential ID
|
||||
Type string `json:"type"` // "public-key"
|
||||
Response PasskeyAuthenticatorAttestationResponse `json:"response"`
|
||||
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
|
||||
Transports []string `json:"transports,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyAuthenticatorAttestationResponse contains attestation data
|
||||
type PasskeyAuthenticatorAttestationResponse struct {
|
||||
ClientDataJSON []byte `json:"clientDataJSON"`
|
||||
AttestationObject []byte `json:"attestationObject"`
|
||||
Transports []string `json:"transports,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyAuthenticationResponse contains the client's authentication response
|
||||
type PasskeyAuthenticationResponse struct {
|
||||
ID string `json:"id"` // Base64URL encoded credential ID
|
||||
RawID []byte `json:"rawId"` // Raw credential ID
|
||||
Type string `json:"type"` // "public-key"
|
||||
Response PasskeyAuthenticatorAssertionResponse `json:"response"`
|
||||
ClientExtensionResults map[string]any `json:"clientExtensionResults,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyAuthenticatorAssertionResponse contains assertion data
|
||||
type PasskeyAuthenticatorAssertionResponse struct {
|
||||
ClientDataJSON []byte `json:"clientDataJSON"`
|
||||
AuthenticatorData []byte `json:"authenticatorData"`
|
||||
Signature []byte `json:"signature"`
|
||||
UserHandle []byte `json:"userHandle,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyProvider handles passkey registration and authentication
|
||||
type PasskeyProvider interface {
|
||||
// BeginRegistration creates registration options for a new passkey
|
||||
BeginRegistration(ctx context.Context, userID int, username, displayName string) (*PasskeyRegistrationOptions, error)
|
||||
|
||||
// CompleteRegistration verifies and stores a new passkey credential
|
||||
CompleteRegistration(ctx context.Context, userID int, response PasskeyRegistrationResponse, expectedChallenge []byte) (*PasskeyCredential, error)
|
||||
|
||||
// BeginAuthentication creates authentication options for passkey login
|
||||
BeginAuthentication(ctx context.Context, username string) (*PasskeyAuthenticationOptions, error)
|
||||
|
||||
// CompleteAuthentication verifies a passkey assertion and returns the user
|
||||
CompleteAuthentication(ctx context.Context, response PasskeyAuthenticationResponse, expectedChallenge []byte) (int, error)
|
||||
|
||||
// GetCredentials returns all passkey credentials for a user
|
||||
GetCredentials(ctx context.Context, userID int) ([]PasskeyCredential, error)
|
||||
|
||||
// DeleteCredential removes a passkey credential
|
||||
DeleteCredential(ctx context.Context, userID int, credentialID string) error
|
||||
|
||||
// UpdateCredentialName updates the friendly name of a credential
|
||||
UpdateCredentialName(ctx context.Context, userID int, credentialID string, name string) error
|
||||
}
|
||||
|
||||
// PasskeyLoginRequest contains passkey authentication data
|
||||
type PasskeyLoginRequest struct {
|
||||
Response PasskeyAuthenticationResponse `json:"response"`
|
||||
ExpectedChallenge []byte `json:"expected_challenge"`
|
||||
Claims map[string]any `json:"claims"` // Additional login data
|
||||
}
|
||||
|
||||
// PasskeyRegisterRequest contains passkey registration data
|
||||
type PasskeyRegisterRequest struct {
|
||||
UserID int `json:"user_id"`
|
||||
Response PasskeyRegistrationResponse `json:"response"`
|
||||
ExpectedChallenge []byte `json:"expected_challenge"`
|
||||
CredentialName string `json:"credential_name,omitempty"`
|
||||
}
|
||||
|
||||
// PasskeyBeginRegistrationRequest contains options for starting passkey registration
|
||||
type PasskeyBeginRegistrationRequest struct {
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
|
||||
// PasskeyBeginAuthenticationRequest contains options for starting passkey authentication
|
||||
type PasskeyBeginAuthenticationRequest struct {
|
||||
Username string `json:"username,omitempty"` // Optional for resident key flow
|
||||
}
|
||||
|
||||
// ParsePasskeyRegistrationResponse parses a JSON passkey registration response
|
||||
func ParsePasskeyRegistrationResponse(data []byte) (*PasskeyRegistrationResponse, error) {
|
||||
var response PasskeyRegistrationResponse
|
||||
if err := json.Unmarshal(data, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
|
||||
// ParsePasskeyAuthenticationResponse parses a JSON passkey authentication response
|
||||
func ParsePasskeyAuthenticationResponse(data []byte) (*PasskeyAuthenticationResponse, error) {
|
||||
var response PasskeyAuthenticationResponse
|
||||
if err := json.Unmarshal(data, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &response, nil
|
||||
}
|
||||
432
pkg/security/passkey_examples.go
Normal file
432
pkg/security/passkey_examples.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PasskeyAuthenticationExample demonstrates passkey (WebAuthn/FIDO2) authentication
|
||||
func PasskeyAuthenticationExample() {
|
||||
// Setup database connection
|
||||
db, _ := sql.Open("postgres", "postgres://user:pass@localhost/db")
|
||||
|
||||
// Create passkey provider
|
||||
passkeyProvider := NewDatabasePasskeyProvider(db, DatabasePasskeyProviderOptions{
|
||||
RPID: "example.com", // Your domain
|
||||
RPName: "Example Application", // Display name
|
||||
RPOrigin: "https://example.com", // Expected origin
|
||||
Timeout: 60000, // 60 seconds
|
||||
})
|
||||
|
||||
// Create authenticator with passkey support
|
||||
// Option 1: Pass during creation
|
||||
_ = NewDatabaseAuthenticatorWithOptions(db, DatabaseAuthenticatorOptions{
|
||||
PasskeyProvider: passkeyProvider,
|
||||
})
|
||||
|
||||
// Option 2: Use WithPasskey method
|
||||
auth := NewDatabaseAuthenticator(db).WithPasskey(passkeyProvider)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// === REGISTRATION FLOW ===
|
||||
|
||||
// Step 1: Begin registration
|
||||
regOptions, _ := auth.BeginPasskeyRegistration(ctx, PasskeyBeginRegistrationRequest{
|
||||
UserID: 1,
|
||||
Username: "alice",
|
||||
DisplayName: "Alice Smith",
|
||||
})
|
||||
|
||||
// Send regOptions to client as JSON
|
||||
// Client will call navigator.credentials.create() with these options
|
||||
_ = regOptions
|
||||
|
||||
// Step 2: Complete registration (after client returns credential)
|
||||
// This would come from the client's navigator.credentials.create() response
|
||||
clientResponse := PasskeyRegistrationResponse{
|
||||
ID: "base64-credential-id",
|
||||
RawID: []byte("raw-credential-id"),
|
||||
Type: "public-key",
|
||||
Response: PasskeyAuthenticatorAttestationResponse{
|
||||
ClientDataJSON: []byte("..."),
|
||||
AttestationObject: []byte("..."),
|
||||
},
|
||||
Transports: []string{"internal"},
|
||||
}
|
||||
|
||||
credential, _ := auth.CompletePasskeyRegistration(ctx, PasskeyRegisterRequest{
|
||||
UserID: 1,
|
||||
Response: clientResponse,
|
||||
ExpectedChallenge: regOptions.Challenge,
|
||||
CredentialName: "My iPhone",
|
||||
})
|
||||
|
||||
fmt.Printf("Registered credential: %s\n", credential.ID)
|
||||
|
||||
// === AUTHENTICATION FLOW ===
|
||||
|
||||
// Step 1: Begin authentication
|
||||
authOptions, _ := auth.BeginPasskeyAuthentication(ctx, PasskeyBeginAuthenticationRequest{
|
||||
Username: "alice", // Optional - omit for resident key flow
|
||||
})
|
||||
|
||||
// Send authOptions to client as JSON
|
||||
// Client will call navigator.credentials.get() with these options
|
||||
_ = authOptions
|
||||
|
||||
// Step 2: Complete authentication (after client returns assertion)
|
||||
// This would come from the client's navigator.credentials.get() response
|
||||
clientAssertion := PasskeyAuthenticationResponse{
|
||||
ID: "base64-credential-id",
|
||||
RawID: []byte("raw-credential-id"),
|
||||
Type: "public-key",
|
||||
Response: PasskeyAuthenticatorAssertionResponse{
|
||||
ClientDataJSON: []byte("..."),
|
||||
AuthenticatorData: []byte("..."),
|
||||
Signature: []byte("..."),
|
||||
},
|
||||
}
|
||||
|
||||
loginResponse, _ := auth.LoginWithPasskey(ctx, PasskeyLoginRequest{
|
||||
Response: clientAssertion,
|
||||
ExpectedChallenge: authOptions.Challenge,
|
||||
Claims: map[string]any{
|
||||
"ip_address": "192.168.1.1",
|
||||
"user_agent": "Mozilla/5.0...",
|
||||
},
|
||||
})
|
||||
|
||||
fmt.Printf("Logged in user: %s with token: %s\n",
|
||||
loginResponse.User.UserName, loginResponse.Token)
|
||||
|
||||
// === CREDENTIAL MANAGEMENT ===
|
||||
|
||||
// Get all credentials for a user
|
||||
credentials, _ := auth.GetPasskeyCredentials(ctx, 1)
|
||||
for i := range credentials {
|
||||
fmt.Printf("Credential: %s (created: %s, last used: %s)\n",
|
||||
credentials[i].Name, credentials[i].CreatedAt, credentials[i].LastUsedAt)
|
||||
}
|
||||
|
||||
// Update credential name
|
||||
_ = auth.UpdatePasskeyCredentialName(ctx, 1, credential.ID, "My New iPhone")
|
||||
|
||||
// Delete credential
|
||||
_ = auth.DeletePasskeyCredential(ctx, 1, credential.ID)
|
||||
}
|
||||
|
||||
// PasskeyHTTPHandlersExample shows HTTP handlers for passkey authentication
|
||||
func PasskeyHTTPHandlersExample(auth *DatabaseAuthenticator) {
|
||||
// Store challenges in session/cache in production
|
||||
challenges := make(map[string][]byte)
|
||||
|
||||
// Begin registration endpoint
|
||||
http.HandleFunc("/api/passkey/register/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID int `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
DisplayName string `json:"display_name"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
options, err := auth.BeginPasskeyRegistration(r.Context(), PasskeyBeginRegistrationRequest{
|
||||
UserID: req.UserID,
|
||||
Username: req.Username,
|
||||
DisplayName: req.DisplayName,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store challenge for verification (use session ID as key in production)
|
||||
sessionID := "session-123"
|
||||
challenges[sessionID] = options.Challenge
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(options)
|
||||
})
|
||||
|
||||
// Complete registration endpoint
|
||||
http.HandleFunc("/api/passkey/register/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
UserID int `json:"user_id"`
|
||||
Response PasskeyRegistrationResponse `json:"response"`
|
||||
CredentialName string `json:"credential_name"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Get stored challenge (from session in production)
|
||||
sessionID := "session-123"
|
||||
challenge := challenges[sessionID]
|
||||
delete(challenges, sessionID)
|
||||
|
||||
credential, err := auth.CompletePasskeyRegistration(r.Context(), PasskeyRegisterRequest{
|
||||
UserID: req.UserID,
|
||||
Response: req.Response,
|
||||
ExpectedChallenge: challenge,
|
||||
CredentialName: req.CredentialName,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(credential)
|
||||
})
|
||||
|
||||
// Begin authentication endpoint
|
||||
http.HandleFunc("/api/passkey/login/begin", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"` // Optional
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
options, err := auth.BeginPasskeyAuthentication(r.Context(), PasskeyBeginAuthenticationRequest{
|
||||
Username: req.Username,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Store challenge for verification (use session ID as key in production)
|
||||
sessionID := "session-456"
|
||||
challenges[sessionID] = options.Challenge
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(options)
|
||||
})
|
||||
|
||||
// Complete authentication endpoint
|
||||
http.HandleFunc("/api/passkey/login/complete", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Response PasskeyAuthenticationResponse `json:"response"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
// Get stored challenge (from session in production)
|
||||
sessionID := "session-456"
|
||||
challenge := challenges[sessionID]
|
||||
delete(challenges, sessionID)
|
||||
|
||||
loginResponse, err := auth.LoginWithPasskey(r.Context(), PasskeyLoginRequest{
|
||||
Response: req.Response,
|
||||
ExpectedChallenge: challenge,
|
||||
Claims: map[string]any{
|
||||
"ip_address": r.RemoteAddr,
|
||||
"user_agent": r.UserAgent(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set session cookie
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "session_token",
|
||||
Value: loginResponse.Token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(loginResponse)
|
||||
})
|
||||
|
||||
// List credentials endpoint
|
||||
http.HandleFunc("/api/passkey/credentials", func(w http.ResponseWriter, r *http.Request) {
|
||||
// Get user from authenticated session
|
||||
userCtx, err := auth.Authenticate(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
credentials, err := auth.GetPasskeyCredentials(r.Context(), userCtx.UserID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(credentials)
|
||||
})
|
||||
|
||||
// Delete credential endpoint
|
||||
http.HandleFunc("/api/passkey/credentials/delete", func(w http.ResponseWriter, r *http.Request) {
|
||||
userCtx, err := auth.Authenticate(r)
|
||||
if err != nil {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
CredentialID string `json:"credential_id"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
err = auth.DeletePasskeyCredential(r.Context(), userCtx.UserID, req.CredentialID)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
}
|
||||
|
||||
// PasskeyClientSideExample shows the client-side JavaScript code needed
|
||||
func PasskeyClientSideExample() string {
|
||||
return `
|
||||
// === CLIENT-SIDE JAVASCRIPT FOR PASSKEY AUTHENTICATION ===
|
||||
|
||||
// Helper function to convert base64 to ArrayBuffer
|
||||
function base64ToArrayBuffer(base64) {
|
||||
const binary = atob(base64);
|
||||
const bytes = new Uint8Array(binary.length);
|
||||
for (let i = 0; i < binary.length; i++) {
|
||||
bytes[i] = binary.charCodeAt(i);
|
||||
}
|
||||
return bytes.buffer;
|
||||
}
|
||||
|
||||
// Helper function to convert ArrayBuffer to base64
|
||||
function arrayBufferToBase64(buffer) {
|
||||
const bytes = new Uint8Array(buffer);
|
||||
let binary = '';
|
||||
for (let i = 0; i < bytes.length; i++) {
|
||||
binary += String.fromCharCode(bytes[i]);
|
||||
}
|
||||
return btoa(binary);
|
||||
}
|
||||
|
||||
// === REGISTRATION ===
|
||||
|
||||
async function registerPasskey(userId, username, displayName) {
|
||||
// Step 1: Get registration options from server
|
||||
const optionsResponse = await fetch('/api/passkey/register/begin', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ user_id: userId, username, display_name: displayName })
|
||||
});
|
||||
const options = await optionsResponse.json();
|
||||
|
||||
// Convert base64 strings to ArrayBuffers
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
options.user.id = base64ToArrayBuffer(options.user.id);
|
||||
if (options.excludeCredentials) {
|
||||
options.excludeCredentials = options.excludeCredentials.map(cred => ({
|
||||
...cred,
|
||||
id: base64ToArrayBuffer(cred.id)
|
||||
}));
|
||||
}
|
||||
|
||||
// Step 2: Create credential using WebAuthn API
|
||||
const credential = await navigator.credentials.create({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Step 3: Send credential to server
|
||||
const credentialResponse = {
|
||||
id: credential.id,
|
||||
rawId: arrayBufferToBase64(credential.rawId),
|
||||
type: credential.type,
|
||||
response: {
|
||||
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
|
||||
attestationObject: arrayBufferToBase64(credential.response.attestationObject)
|
||||
},
|
||||
transports: credential.response.getTransports ? credential.response.getTransports() : []
|
||||
};
|
||||
|
||||
const completeResponse = await fetch('/api/passkey/register/complete', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
user_id: userId,
|
||||
response: credentialResponse,
|
||||
credential_name: 'My Device'
|
||||
})
|
||||
});
|
||||
|
||||
return await completeResponse.json();
|
||||
}
|
||||
|
||||
// === AUTHENTICATION ===
|
||||
|
||||
async function loginWithPasskey(username) {
|
||||
// Step 1: Get authentication options from server
|
||||
const optionsResponse = await fetch('/api/passkey/login/begin', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ username })
|
||||
});
|
||||
const options = await optionsResponse.json();
|
||||
|
||||
// Convert base64 strings to ArrayBuffers
|
||||
options.challenge = base64ToArrayBuffer(options.challenge);
|
||||
if (options.allowCredentials) {
|
||||
options.allowCredentials = options.allowCredentials.map(cred => ({
|
||||
...cred,
|
||||
id: base64ToArrayBuffer(cred.id)
|
||||
}));
|
||||
}
|
||||
|
||||
// Step 2: Get credential using WebAuthn API
|
||||
const credential = await navigator.credentials.get({
|
||||
publicKey: options
|
||||
});
|
||||
|
||||
// Step 3: Send assertion to server
|
||||
const assertionResponse = {
|
||||
id: credential.id,
|
||||
rawId: arrayBufferToBase64(credential.rawId),
|
||||
type: credential.type,
|
||||
response: {
|
||||
clientDataJSON: arrayBufferToBase64(credential.response.clientDataJSON),
|
||||
authenticatorData: arrayBufferToBase64(credential.response.authenticatorData),
|
||||
signature: arrayBufferToBase64(credential.response.signature),
|
||||
userHandle: credential.response.userHandle ? arrayBufferToBase64(credential.response.userHandle) : null
|
||||
}
|
||||
};
|
||||
|
||||
const loginResponse = await fetch('/api/passkey/login/complete', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ response: assertionResponse })
|
||||
});
|
||||
|
||||
return await loginResponse.json();
|
||||
}
|
||||
|
||||
// === USAGE ===
|
||||
|
||||
// Register a new passkey
|
||||
document.getElementById('register-btn').addEventListener('click', async () => {
|
||||
try {
|
||||
const result = await registerPasskey(1, 'alice', 'Alice Smith');
|
||||
console.log('Passkey registered:', result);
|
||||
} catch (error) {
|
||||
console.error('Registration failed:', error);
|
||||
}
|
||||
});
|
||||
|
||||
// Login with passkey
|
||||
document.getElementById('login-btn').addEventListener('click', async () => {
|
||||
try {
|
||||
const result = await loginWithPasskey('alice');
|
||||
console.log('Logged in:', result);
|
||||
} catch (error) {
|
||||
console.error('Login failed:', error);
|
||||
}
|
||||
});
|
||||
`
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user