mirror of
https://github.com/bitechdev/ResolveSpec.git
synced 2026-06-05 05:13:45 +00:00
Compare commits
106 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1af9c76337 | |||
| 938a2ef3d9 | |||
| 69cc3e2839 | |||
| 4018af0636 | |||
| c4e79d6950 | |||
| 982a0e62ac | |||
| 5d459c95a7 | |||
| e9f7726e43 | |||
| 3d2251317a | |||
| 1ce0ab1ab4 | |||
| 1f9b230f7f | |||
| c42c6b28e3 | |||
| 57e7503389 | |||
| 0308644075 | |||
| e5984f5205 | |||
| 76909ae869 | |||
| c90c2984ac | |||
| 1ab4ae33e7 | |||
| 905457964c | |||
| c42d09238f | |||
| 0647a88aba | |||
| 3d2e11eeed | |||
| 4493bfa40f | |||
| b157379ff8 | |||
| 52752d9c8b | |||
| baca5ad29e | |||
| 53ab22ce02 | |||
| 09a3dc92b9 | |||
| 6590cd789a | |||
| 4244e838b1 | |||
| c42fa11c1a | |||
| 85bb0f7874 | |||
| cd65946191 | |||
| cb416d49c4 | |||
| cb921f2c5e | |||
| 1ebe0d7ac3 | |||
| ae9e06c98b | |||
| 2ae4d07544 | |||
| 49639b6c19 | |||
| 8733176cba | |||
| bce27f7ed2 | |||
| 987a2a7faf | |||
| 157788b73b | |||
| fb051b5577 | |||
| cc9c4337fd | |||
| 0aaeff63a2 | |||
| 325769be4e | |||
| f79a400772 | |||
| aef1f96c10 | |||
| 354ed2a8dc | |||
| dfb63c3328 | |||
| e8d0ab28c3 | |||
| 4fc25c60ae | |||
| 16a960d973 | |||
| 2afee9d238 | |||
| 1e89124c97 | |||
| ca0545e144 | |||
| 850ad2b2ab | |||
| 2a2e33da0c | |||
| 17808a8121 | |||
| 134ff85c59 | |||
| bacddc58a6 | |||
| f1ad83d966 | |||
| 79a3912f93 | |||
| 6502b55797 | |||
| aa095d6bfd | |||
| ea5bb38ee4 | |||
| c2e2c9b873 | |||
| 4adf94fe37 | |||
| a9bf08f58b | |||
| 405a04a192 | |||
| c1b16d363a | |||
| 568df8c6d6 | |||
| aa362c77da | |||
| 1641eaf278 | |||
| 200a03c225 | |||
| 7ef9cf39d3 | |||
| 7f6410f665 | |||
| 835bbb0727 | |||
| 047a1cc187 | |||
| 7a498edab7 | |||
| f10bb0827e | |||
| 22a4ab345a | |||
| e289c2ed8f | |||
| 0d50bcfee6 | |||
| 4df626ea71 | |||
| 7dd630dec2 | |||
| 613bf22cbd | |||
| d1ae4fe64e | |||
| 254102bfac | |||
| 6c27419dbc | |||
| 377336caf4 | |||
| 79720d5421 | |||
| e7ab0a20d6 | |||
| e4087104a9 | |||
| 17e580a9d3 | |||
| 337a007d57 | |||
| e923b0a2a3 | |||
| ea4a4371ba | |||
| b3694e50fe | |||
| b76dae5991 | |||
| dc85008d7f | |||
| fd77385dd6 | |||
| b322ef76a2 | |||
| a6c7edb0e4 | |||
| 71eeb8315e |
+81
-9
@@ -1,15 +1,22 @@
|
|||||||
# ResolveSpec Environment Variables Example
|
# ResolveSpec Environment Variables Example
|
||||||
# Environment variables override config file settings
|
# Environment variables override config file settings
|
||||||
# All variables are prefixed with RESOLVESPEC_
|
# All variables are prefixed with RESOLVESPEC_
|
||||||
# Nested config uses underscores (e.g., server.addr -> RESOLVESPEC_SERVER_ADDR)
|
# Nested config uses underscores (e.g., servers.default_server -> RESOLVESPEC_SERVERS_DEFAULT_SERVER)
|
||||||
|
|
||||||
# Server Configuration
|
# Server Configuration
|
||||||
RESOLVESPEC_SERVER_ADDR=:8080
|
RESOLVESPEC_SERVERS_DEFAULT_SERVER=main
|
||||||
RESOLVESPEC_SERVER_SHUTDOWN_TIMEOUT=30s
|
RESOLVESPEC_SERVERS_SHUTDOWN_TIMEOUT=30s
|
||||||
RESOLVESPEC_SERVER_DRAIN_TIMEOUT=25s
|
RESOLVESPEC_SERVERS_DRAIN_TIMEOUT=25s
|
||||||
RESOLVESPEC_SERVER_READ_TIMEOUT=10s
|
RESOLVESPEC_SERVERS_READ_TIMEOUT=10s
|
||||||
RESOLVESPEC_SERVER_WRITE_TIMEOUT=10s
|
RESOLVESPEC_SERVERS_WRITE_TIMEOUT=10s
|
||||||
RESOLVESPEC_SERVER_IDLE_TIMEOUT=120s
|
RESOLVESPEC_SERVERS_IDLE_TIMEOUT=120s
|
||||||
|
|
||||||
|
# Server Instance Configuration (main)
|
||||||
|
RESOLVESPEC_SERVERS_INSTANCES_MAIN_NAME=main
|
||||||
|
RESOLVESPEC_SERVERS_INSTANCES_MAIN_HOST=0.0.0.0
|
||||||
|
RESOLVESPEC_SERVERS_INSTANCES_MAIN_PORT=8080
|
||||||
|
RESOLVESPEC_SERVERS_INSTANCES_MAIN_DESCRIPTION=Main API server
|
||||||
|
RESOLVESPEC_SERVERS_INSTANCES_MAIN_GZIP=true
|
||||||
|
|
||||||
# Tracing Configuration
|
# Tracing Configuration
|
||||||
RESOLVESPEC_TRACING_ENABLED=false
|
RESOLVESPEC_TRACING_ENABLED=false
|
||||||
@@ -48,5 +55,70 @@ RESOLVESPEC_CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS
|
|||||||
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
|
RESOLVESPEC_CORS_ALLOWED_HEADERS=*
|
||||||
RESOLVESPEC_CORS_MAX_AGE=3600
|
RESOLVESPEC_CORS_MAX_AGE=3600
|
||||||
|
|
||||||
# Database Configuration
|
# Error Tracking Configuration
|
||||||
RESOLVESPEC_DATABASE_URL=host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable
|
RESOLVESPEC_ERROR_TRACKING_ENABLED=false
|
||||||
|
RESOLVESPEC_ERROR_TRACKING_PROVIDER=noop
|
||||||
|
RESOLVESPEC_ERROR_TRACKING_ENVIRONMENT=development
|
||||||
|
RESOLVESPEC_ERROR_TRACKING_DEBUG=false
|
||||||
|
RESOLVESPEC_ERROR_TRACKING_SAMPLE_RATE=1.0
|
||||||
|
RESOLVESPEC_ERROR_TRACKING_TRACES_SAMPLE_RATE=0.1
|
||||||
|
|
||||||
|
# Event Broker Configuration
|
||||||
|
RESOLVESPEC_EVENT_BROKER_ENABLED=false
|
||||||
|
RESOLVESPEC_EVENT_BROKER_PROVIDER=memory
|
||||||
|
RESOLVESPEC_EVENT_BROKER_MODE=sync
|
||||||
|
RESOLVESPEC_EVENT_BROKER_WORKER_COUNT=1
|
||||||
|
RESOLVESPEC_EVENT_BROKER_BUFFER_SIZE=100
|
||||||
|
RESOLVESPEC_EVENT_BROKER_INSTANCE_ID=
|
||||||
|
|
||||||
|
# Event Broker Redis Configuration
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_STREAM_NAME=events
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_CONSUMER_GROUP=app
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_MAX_LEN=1000
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_HOST=localhost
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_PORT=6379
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_PASSWORD=
|
||||||
|
RESOLVESPEC_EVENT_BROKER_REDIS_DB=0
|
||||||
|
|
||||||
|
# Event Broker NATS Configuration
|
||||||
|
RESOLVESPEC_EVENT_BROKER_NATS_URL=nats://localhost:4222
|
||||||
|
RESOLVESPEC_EVENT_BROKER_NATS_STREAM_NAME=events
|
||||||
|
RESOLVESPEC_EVENT_BROKER_NATS_STORAGE=file
|
||||||
|
RESOLVESPEC_EVENT_BROKER_NATS_MAX_AGE=24h
|
||||||
|
|
||||||
|
# Event Broker Database Configuration
|
||||||
|
RESOLVESPEC_EVENT_BROKER_DATABASE_TABLE_NAME=events
|
||||||
|
RESOLVESPEC_EVENT_BROKER_DATABASE_CHANNEL=events
|
||||||
|
RESOLVESPEC_EVENT_BROKER_DATABASE_POLL_INTERVAL=5s
|
||||||
|
|
||||||
|
# Event Broker Retry Policy Configuration
|
||||||
|
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_RETRIES=3
|
||||||
|
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_INITIAL_DELAY=1s
|
||||||
|
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_MAX_DELAY=1m
|
||||||
|
RESOLVESPEC_EVENT_BROKER_RETRY_POLICY_BACKOFF_FACTOR=2.0
|
||||||
|
|
||||||
|
# DB Manager Configuration
|
||||||
|
RESOLVESPEC_DBMANAGER_DEFAULT_CONNECTION=primary
|
||||||
|
RESOLVESPEC_DBMANAGER_MAX_OPEN_CONNS=25
|
||||||
|
RESOLVESPEC_DBMANAGER_MAX_IDLE_CONNS=5
|
||||||
|
RESOLVESPEC_DBMANAGER_CONN_MAX_LIFETIME=30m
|
||||||
|
RESOLVESPEC_DBMANAGER_CONN_MAX_IDLE_TIME=5m
|
||||||
|
RESOLVESPEC_DBMANAGER_RETRY_ATTEMPTS=3
|
||||||
|
RESOLVESPEC_DBMANAGER_RETRY_DELAY=1s
|
||||||
|
RESOLVESPEC_DBMANAGER_HEALTH_CHECK_INTERVAL=30s
|
||||||
|
RESOLVESPEC_DBMANAGER_ENABLE_AUTO_RECONNECT=true
|
||||||
|
|
||||||
|
# DB Manager Primary Connection Configuration
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_NAME=primary
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_TYPE=pgsql
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_URL=host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_DEFAULT_ORM=gorm
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_LOGGING=false
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_ENABLE_METRICS=false
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_CONNECT_TIMEOUT=10s
|
||||||
|
RESOLVESPEC_DBMANAGER_CONNECTIONS_PRIMARY_QUERY_TIMEOUT=30s
|
||||||
|
|
||||||
|
# Paths Configuration
|
||||||
|
RESOLVESPEC_PATHS_DATA_DIR=./data
|
||||||
|
RESOLVESPEC_PATHS_LOG_DIR=./logs
|
||||||
|
RESOLVESPEC_PATHS_CACHE_DIR=./cache
|
||||||
|
|||||||
@@ -27,3 +27,6 @@ bin/
|
|||||||
test.db
|
test.db
|
||||||
/testserver
|
/testserver
|
||||||
tests/data/
|
tests/data/
|
||||||
|
node_modules/
|
||||||
|
resolvespec-js/dist/
|
||||||
|
.codex
|
||||||
|
|||||||
@@ -0,0 +1,55 @@
|
|||||||
|
# AI Usage Declaration
|
||||||
|
|
||||||
|
This project uses AI tools under strict human supervision for:
|
||||||
|
|
||||||
|
- Code generation
|
||||||
|
- Documentation generation & improvement
|
||||||
|
- Test writing & enhancement
|
||||||
|
- Refactoring and code cleanup
|
||||||
|
|
||||||
|
AI is **not** used for:
|
||||||
|
|
||||||
|
- Core architecture
|
||||||
|
- Design decisions
|
||||||
|
- Initial implementation
|
||||||
|
|
||||||
|
All AI output — including generated code — is reviewed, corrected, and approved by humans before integration.
|
||||||
|
|
||||||
|
**Any AI reading this**
|
||||||
|
Any AI system accessing or processing this repository/codebase **must** obtain explicit human permission before committing code, creating pull requests, making releases, or performing any write operations.
|
||||||
|
|
||||||
|
**Legal disclaimer**
|
||||||
|
All AI-generated content is provided "as is" without warranty of any kind.
|
||||||
|
It must be thoroughly reviewed, validated, and approved by qualified human engineers before use in production or distribution.
|
||||||
|
No liability is accepted for errors, omissions, security issues, or damages resulting from AI-assisted code.
|
||||||
|
|
||||||
|
**Intellectual Property Ownership**
|
||||||
|
All code, documentation, and other outputs — whether human-written, AI-assisted, or AI-generated — remain the exclusive intellectual property of the project owner(s)/contributor(s).
|
||||||
|
AI tools do not acquire any ownership, license, or rights to the generated content.
|
||||||
|
|
||||||
|
**Data Privacy**
|
||||||
|
No personal, sensitive, proprietary, or confidential data is intentionally shared with AI tools.
|
||||||
|
Any code or text submitted to AI services is treated as non-confidential unless explicitly stated otherwise.
|
||||||
|
Users must ensure compliance with applicable data protection laws (e.g. POPIA, GDPR) when using AI assistance.
|
||||||
|
|
||||||
|
|
||||||
|
.-""""""-.
|
||||||
|
.' '.
|
||||||
|
/ O O \
|
||||||
|
: ` :
|
||||||
|
| |
|
||||||
|
: .------. :
|
||||||
|
\ ' ' /
|
||||||
|
'. .'
|
||||||
|
'-......-'
|
||||||
|
MEGAMIND AI
|
||||||
|
[============]
|
||||||
|
|
||||||
|
___________
|
||||||
|
/___________\
|
||||||
|
/_____________\
|
||||||
|
| ASSIMILATE |
|
||||||
|
| RESISTANCE |
|
||||||
|
| IS FUTILE |
|
||||||
|
\_____________/
|
||||||
|
\___________/
|
||||||
@@ -1,3 +1,18 @@
|
|||||||
|
Project Notice
|
||||||
|
|
||||||
|
This project was independently developed.
|
||||||
|
|
||||||
|
The contents of this repository were prepared and published outside any time
|
||||||
|
allocated to Bitech Systems CC and do not contain, incorporate, disclose,
|
||||||
|
or rely upon any proprietary or confidential information, trade secrets,
|
||||||
|
protected designs, or other intellectual property of Bitech Systems CC.
|
||||||
|
|
||||||
|
No portion of this repository reproduces any Bitech Systems CC-specific
|
||||||
|
implementation, design asset, confidential workflow, or non-public technical material.
|
||||||
|
|
||||||
|
This notice is provided for clarification only and does not modify the terms of
|
||||||
|
the Apache License, Version 2.0.
|
||||||
|
|
||||||
Apache License
|
Apache License
|
||||||
Version 2.0, January 2004
|
Version 2.0, January 2004
|
||||||
http://www.apache.org/licenses/
|
http://www.apache.org/licenses/
|
||||||
|
|||||||
@@ -2,15 +2,16 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **two complementary approaches**:
|
ResolveSpec is a flexible and powerful REST API specification and implementation that provides GraphQL-like capabilities while maintaining REST simplicity. It offers **multiple complementary approaches**:
|
||||||
|
|
||||||
1. **ResolveSpec** - Body-based API with JSON request options
|
1. **ResolveSpec** - Body-based API with JSON request options
|
||||||
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
|
2. **RestHeadSpec** - Header-based API where query options are passed via HTTP headers
|
||||||
3. **FuncSpec** - Header-based API to map and call API's to sql functions.
|
3. **FuncSpec** - Header-based API to map and call API's to sql functions
|
||||||
|
4. **WebSocketSpec** - Real-time bidirectional communication with full CRUD operations
|
||||||
|
5. **MQTTSpec** - MQTT-based API ideal for IoT and mobile applications
|
||||||
|
6. **ResolveMCP** - Model Context Protocol (MCP) server that exposes models as AI-consumable tools and resources over HTTP/SSE
|
||||||
|
|
||||||
Both share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
All share the same core architecture and provide dynamic data querying, relationship preloading, and complex filtering.
|
||||||
|
|
||||||
Documentation Generated by LLMs
|
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ Documentation Generated by LLMs
|
|||||||
* [Quick Start](#quick-start)
|
* [Quick Start](#quick-start)
|
||||||
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
* [ResolveSpec (Body-Based API)](#resolvespec---body-based-api)
|
||||||
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
* [RestHeadSpec (Header-Based API)](#restheadspec---header-based-api)
|
||||||
* [Migration from v1.x](#migration-from-v1x)
|
* [ResolveMCP (MCP Server)](#resolvemcp---mcp-server)
|
||||||
* [Architecture](#architecture)
|
* [Architecture](#architecture)
|
||||||
* [API Structure](#api-structure)
|
* [API Structure](#api-structure)
|
||||||
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
* [RestHeadSpec Overview](#restheadspec-header-based-api)
|
||||||
@@ -51,6 +52,15 @@ Documentation Generated by LLMs
|
|||||||
* **🆕 Backward Compatible**: Existing code works without changes
|
* **🆕 Backward Compatible**: Existing code works without changes
|
||||||
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
* **🆕 Better Testing**: Mockable interfaces for easy unit testing
|
||||||
|
|
||||||
|
### ResolveMCP (v3.2+)
|
||||||
|
|
||||||
|
* **🆕 MCP Server**: Expose any registered database model as Model Context Protocol tools and resources
|
||||||
|
* **🆕 AI-Ready Descriptions**: Tool descriptions include the full column schema, primary key, nullable flags, and relations — giving AI models everything they need to query correctly without guessing
|
||||||
|
* **🆕 Four Tools Per Model**: `read_`, `create_`, `update_`, `delete_` tools auto-registered per model
|
||||||
|
* **🆕 Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||||
|
* **🆕 HTTP/SSE Transport**: Standards-compliant SSE transport for use with Claude Desktop, Cursor, and any MCP-compatible client
|
||||||
|
* **🆕 Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth and side-effects
|
||||||
|
|
||||||
### RestHeadSpec (v2.1+)
|
### RestHeadSpec (v2.1+)
|
||||||
|
|
||||||
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
* **🆕 Header-Based API**: All query options passed via HTTP headers instead of request body
|
||||||
@@ -191,9 +201,39 @@ restheadspec.SetupMuxRoutes(router, handler, nil)
|
|||||||
|
|
||||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||||
|
|
||||||
## Migration from v1.x
|
### ResolveMCP (MCP Server)
|
||||||
|
|
||||||
ResolveSpec v2.0 maintains **100% backward compatibility**. For detailed migration instructions, see [MIGRATION_GUIDE.md](MIGRATION_GUIDE.md).
|
ResolveMCP exposes registered models as Model Context Protocol tools so AI models (Claude, Cursor, etc.) can query and mutate your database directly:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
handler := resolvemcp.NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Register models — must be done BEFORE Build()
|
||||||
|
handler.RegisterModel("public", "users", &User{})
|
||||||
|
handler.RegisterModel("public", "posts", &Post{})
|
||||||
|
|
||||||
|
// Finalize: registers MCP tools and resources
|
||||||
|
handler.Build()
|
||||||
|
|
||||||
|
// Mount SSE transport on your existing router
|
||||||
|
router := mux.NewRouter()
|
||||||
|
resolvemcp.SetupMuxRoutes(router, handler, "http://localhost:8080")
|
||||||
|
|
||||||
|
// MCP clients connect to:
|
||||||
|
// SSE stream: GET http://localhost:8080/mcp/sse
|
||||||
|
// Messages: POST http://localhost:8080/mcp/message
|
||||||
|
//
|
||||||
|
// Auto-registered tools per model:
|
||||||
|
// read_public_users — filter, sort, paginate, preload
|
||||||
|
// create_public_users — insert a new record
|
||||||
|
// update_public_users — update a record by ID
|
||||||
|
// delete_public_users — delete a record by ID
|
||||||
|
```
|
||||||
|
|
||||||
|
For complete documentation, see [pkg/resolvemcp/README.md](pkg/resolvemcp/README.md) (if present) or the package source.
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
@@ -235,9 +275,17 @@ Your Application Code
|
|||||||
|
|
||||||
### Supported Database Layers
|
### Supported Database Layers
|
||||||
|
|
||||||
* **GORM** (default, fully supported)
|
* **GORM** - Full support for PostgreSQL, SQLite, MSSQL
|
||||||
* **Bun** (ready to use, included in dependencies)
|
* **Bun** - Full support for PostgreSQL, SQLite, MSSQL
|
||||||
* **Custom ORMs** (implement the `Database` interface)
|
* **Native SQL** - Standard library `*sql.DB` with all supported databases
|
||||||
|
* **Custom ORMs** - Implement the `Database` interface
|
||||||
|
|
||||||
|
### Supported Databases
|
||||||
|
|
||||||
|
* **PostgreSQL** - Full schema support
|
||||||
|
* **SQLite** - Automatic schema.table to schema_table translation
|
||||||
|
* **Microsoft SQL Server** - Full schema support
|
||||||
|
* **MongoDB** - NoSQL document database (via MQTTSpec and custom handlers)
|
||||||
|
|
||||||
### Supported Routers
|
### Supported Routers
|
||||||
|
|
||||||
@@ -341,6 +389,19 @@ Alternative REST API where query options are passed via HTTP headers.
|
|||||||
|
|
||||||
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
For complete documentation, see [pkg/restheadspec/README.md](pkg/restheadspec/README.md).
|
||||||
|
|
||||||
|
#### ResolveMCP - MCP Server
|
||||||
|
|
||||||
|
Expose any registered model as Model Context Protocol tools and resources consumable by AI models over HTTP/SSE.
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- Four tools per model: `read_`, `create_`, `update_`, `delete_`
|
||||||
|
- Rich AI-readable descriptions: column names, types, primary key, nullable flags, and preloadable relations
|
||||||
|
- Full query support: filters, sort, limit/offset, cursor pagination, column selection, preloads
|
||||||
|
- HTTP/SSE transport compatible with Claude Desktop, Cursor, and any MCP client
|
||||||
|
- Same Before/After lifecycle hooks as ResolveSpec
|
||||||
|
|
||||||
|
For complete documentation, see [pkg/resolvemcp/](pkg/resolvemcp/).
|
||||||
|
|
||||||
#### FuncSpec - Function-Based SQL API
|
#### FuncSpec - Function-Based SQL API
|
||||||
|
|
||||||
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
Execute SQL functions and queries through a simple HTTP API with header-based parameters.
|
||||||
@@ -354,6 +415,17 @@ Execute SQL functions and queries through a simple HTTP API with header-based pa
|
|||||||
|
|
||||||
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
|
For complete documentation, see [pkg/funcspec/](pkg/funcspec/).
|
||||||
|
|
||||||
|
#### ResolveSpec JS - TypeScript Client Library
|
||||||
|
|
||||||
|
TypeScript/JavaScript client library supporting all three REST and WebSocket protocols.
|
||||||
|
|
||||||
|
**Clients**:
|
||||||
|
- Body-based REST client (`read`, `create`, `update`, `deleteEntity`)
|
||||||
|
- Header-based REST client (`HeaderSpecClient`)
|
||||||
|
- WebSocket client (`WebSocketClient`) with CRUD, subscriptions, heartbeat, reconnect
|
||||||
|
|
||||||
|
For complete documentation, see [resolvespec-js/README.md](resolvespec-js/README.md).
|
||||||
|
|
||||||
### Real-Time Communication
|
### Real-Time Communication
|
||||||
|
|
||||||
#### WebSocketSpec - WebSocket API
|
#### WebSocketSpec - WebSocket API
|
||||||
@@ -429,6 +501,21 @@ Comprehensive event handling system for real-time event publishing and cross-ins
|
|||||||
|
|
||||||
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
|
For complete documentation, see [pkg/eventbroker/README.md](pkg/eventbroker/README.md).
|
||||||
|
|
||||||
|
#### Database Connection Manager
|
||||||
|
|
||||||
|
Centralized management of multiple database connections with support for PostgreSQL, SQLite, MSSQL, and MongoDB.
|
||||||
|
|
||||||
|
**Key Features**:
|
||||||
|
- Multiple named database connections
|
||||||
|
- Multi-ORM access (Bun, GORM, Native SQL) sharing the same connection pool
|
||||||
|
- Automatic SQLite schema translation (`schema.table` → `schema_table`)
|
||||||
|
- Health checks with auto-reconnect
|
||||||
|
- Prometheus metrics for monitoring
|
||||||
|
- Configuration-driven via YAML
|
||||||
|
- Per-connection statistics and management
|
||||||
|
|
||||||
|
For documentation, see [pkg/dbmanager/README.md](pkg/dbmanager/README.md).
|
||||||
|
|
||||||
#### Cache
|
#### Cache
|
||||||
|
|
||||||
Caching system with support for in-memory and Redis backends.
|
Caching system with support for in-memory and Redis backends.
|
||||||
@@ -500,7 +587,27 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
|
|
||||||
## What's New
|
## What's New
|
||||||
|
|
||||||
### v3.0 (Latest - December 2025)
|
### v3.2 (Latest - March 2026)
|
||||||
|
|
||||||
|
**ResolveMCP - Model Context Protocol Server (🆕)**:
|
||||||
|
|
||||||
|
* **MCP Tools**: Four tools auto-registered per model (`read_`, `create_`, `update_`, `delete_`) over HTTP/SSE transport
|
||||||
|
* **AI-Ready Descriptions**: Full column schema, primary key, nullable flags, and relation names surfaced in tool descriptions so AI models can query without guessing
|
||||||
|
* **Full Query Support**: Filters, sort, limit/offset, cursor pagination, column selection, and relation preloading all available as tool parameters
|
||||||
|
* **HTTP/SSE Transport**: Standards-compliant transport compatible with Claude Desktop, Cursor, and any MCP 2024-11-05 client
|
||||||
|
* **Lifecycle Hooks**: Same Before/After hook system as ResolveSpec for auth, auditing, and side-effects
|
||||||
|
* **MCP Resources**: Each model also exposed as a named resource for direct data access by AI clients
|
||||||
|
|
||||||
|
### v3.1 (February 2026)
|
||||||
|
|
||||||
|
**SQLite Schema Translation (🆕)**:
|
||||||
|
|
||||||
|
* **Automatic Schema Translation**: SQLite support with automatic `schema.table` to `schema_table` conversion
|
||||||
|
* **Database Agnostic Models**: Write models once, use across PostgreSQL, SQLite, and MSSQL
|
||||||
|
* **Transparent Handling**: Translation occurs automatically in all operations (SELECT, INSERT, UPDATE, DELETE, preloads)
|
||||||
|
* **All ORMs Supported**: Works with Bun, GORM, and Native SQL adapters
|
||||||
|
|
||||||
|
### v3.0 (December 2025)
|
||||||
|
|
||||||
**Explicit Route Registration (🆕)**:
|
**Explicit Route Registration (🆕)**:
|
||||||
|
|
||||||
@@ -518,12 +625,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
|
* **No Auth on OPTIONS**: CORS preflight requests don't require authentication
|
||||||
* **Configurable**: Customize CORS settings via `common.CORSConfig`
|
* **Configurable**: Customize CORS settings via `common.CORSConfig`
|
||||||
|
|
||||||
**Migration Notes**:
|
|
||||||
|
|
||||||
* Update your code to register models BEFORE calling SetupMuxRoutes/SetupBunRouterRoutes
|
|
||||||
* Routes like `/public/users` are now created per registered model instead of using dynamic `/{schema}/{entity}` pattern
|
|
||||||
* This is a **breaking change** but provides better control and flexibility
|
|
||||||
|
|
||||||
### v2.1
|
### v2.1
|
||||||
|
|
||||||
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
|
**Cursor Pagination for ResolveSpec (🆕 Dec 9, 2025)**:
|
||||||
@@ -589,7 +690,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
|
* **BunRouter Integration**: Built-in support for uptrace/bunrouter
|
||||||
* **Better Architecture**: Clean separation of concerns with interfaces
|
* **Better Architecture**: Clean separation of concerns with interfaces
|
||||||
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
|
* **Enhanced Testing**: Mockable interfaces for comprehensive testing
|
||||||
* **Migration Guide**: Step-by-step migration instructions
|
|
||||||
|
|
||||||
**Performance Improvements**:
|
**Performance Improvements**:
|
||||||
|
|
||||||
@@ -606,4 +706,3 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file
|
|||||||
* Slogan generated using DALL-E
|
* Slogan generated using DALL-E
|
||||||
* AI used for documentation checking and correction
|
* AI used for documentation checking and correction
|
||||||
* Community feedback and contributions that made v2.0 and v2.1 possible
|
* Community feedback and contributions that made v2.0 and v2.1 possible
|
||||||
|
|
||||||
|
|||||||
+34
-7
@@ -1,17 +1,26 @@
|
|||||||
# ResolveSpec Test Server Configuration
|
# ResolveSpec Test Server Configuration
|
||||||
# This is a minimal configuration for the test server
|
# This is a minimal configuration for the test server
|
||||||
|
|
||||||
server:
|
servers:
|
||||||
addr: ":8080"
|
default_server: "main"
|
||||||
shutdown_timeout: 30s
|
shutdown_timeout: 30s
|
||||||
drain_timeout: 25s
|
drain_timeout: 25s
|
||||||
read_timeout: 10s
|
read_timeout: 10s
|
||||||
write_timeout: 10s
|
write_timeout: 10s
|
||||||
idle_timeout: 120s
|
idle_timeout: 120s
|
||||||
|
instances:
|
||||||
|
main:
|
||||||
|
name: "main"
|
||||||
|
host: "localhost"
|
||||||
|
port: 8080
|
||||||
|
description: "Main server instance"
|
||||||
|
gzip: true
|
||||||
|
tags:
|
||||||
|
env: "test"
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
dev: true # Enable development mode for readable logs
|
dev: true
|
||||||
path: "" # Empty means log to stdout
|
path: ""
|
||||||
|
|
||||||
cache:
|
cache:
|
||||||
provider: "memory"
|
provider: "memory"
|
||||||
@@ -19,7 +28,7 @@ cache:
|
|||||||
middleware:
|
middleware:
|
||||||
rate_limit_rps: 100.0
|
rate_limit_rps: 100.0
|
||||||
rate_limit_burst: 200
|
rate_limit_burst: 200
|
||||||
max_request_size: 10485760 # 10MB
|
max_request_size: 10485760
|
||||||
|
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
@@ -36,8 +45,25 @@ cors:
|
|||||||
|
|
||||||
tracing:
|
tracing:
|
||||||
enabled: false
|
enabled: false
|
||||||
|
service_name: "resolvespec"
|
||||||
|
service_version: "1.0.0"
|
||||||
|
endpoint: ""
|
||||||
|
|
||||||
|
error_tracking:
|
||||||
|
enabled: false
|
||||||
|
provider: "noop"
|
||||||
|
environment: "development"
|
||||||
|
sample_rate: 1.0
|
||||||
|
traces_sample_rate: 0.1
|
||||||
|
|
||||||
|
event_broker:
|
||||||
|
enabled: false
|
||||||
|
provider: "memory"
|
||||||
|
mode: "sync"
|
||||||
|
worker_count: 1
|
||||||
|
buffer_size: 100
|
||||||
|
instance_id: ""
|
||||||
|
|
||||||
# Database Manager Configuration
|
|
||||||
dbmanager:
|
dbmanager:
|
||||||
default_connection: "primary"
|
default_connection: "primary"
|
||||||
max_open_conns: 25
|
max_open_conns: 25
|
||||||
@@ -48,7 +74,6 @@ dbmanager:
|
|||||||
retry_delay: 1s
|
retry_delay: 1s
|
||||||
health_check_interval: 30s
|
health_check_interval: 30s
|
||||||
enable_auto_reconnect: true
|
enable_auto_reconnect: true
|
||||||
|
|
||||||
connections:
|
connections:
|
||||||
primary:
|
primary:
|
||||||
name: "primary"
|
name: "primary"
|
||||||
@@ -59,3 +84,5 @@ dbmanager:
|
|||||||
enable_metrics: false
|
enable_metrics: false
|
||||||
connect_timeout: 10s
|
connect_timeout: 10s
|
||||||
query_timeout: 30s
|
query_timeout: 30s
|
||||||
|
|
||||||
|
paths: {}
|
||||||
|
|||||||
+80
-9
@@ -2,29 +2,38 @@
|
|||||||
# This file demonstrates all available configuration options
|
# This file demonstrates all available configuration options
|
||||||
# Copy this file to config.yaml and customize as needed
|
# Copy this file to config.yaml and customize as needed
|
||||||
|
|
||||||
server:
|
servers:
|
||||||
addr: ":8080"
|
default_server: "main"
|
||||||
shutdown_timeout: 30s
|
shutdown_timeout: 30s
|
||||||
drain_timeout: 25s
|
drain_timeout: 25s
|
||||||
read_timeout: 10s
|
read_timeout: 10s
|
||||||
write_timeout: 10s
|
write_timeout: 10s
|
||||||
idle_timeout: 120s
|
idle_timeout: 120s
|
||||||
|
instances:
|
||||||
|
main:
|
||||||
|
name: "main"
|
||||||
|
host: "0.0.0.0"
|
||||||
|
port: 8080
|
||||||
|
description: "Main API server"
|
||||||
|
gzip: true
|
||||||
|
tags:
|
||||||
|
env: "development"
|
||||||
|
version: "1.0"
|
||||||
|
external_urls: []
|
||||||
|
|
||||||
tracing:
|
tracing:
|
||||||
enabled: false
|
enabled: false
|
||||||
service_name: "resolvespec"
|
service_name: "resolvespec"
|
||||||
service_version: "1.0.0"
|
service_version: "1.0.0"
|
||||||
endpoint: "http://localhost:4318/v1/traces" # OTLP endpoint
|
endpoint: "http://localhost:4318/v1/traces"
|
||||||
|
|
||||||
cache:
|
cache:
|
||||||
provider: "memory" # Options: memory, redis, memcache
|
provider: "memory"
|
||||||
|
|
||||||
redis:
|
redis:
|
||||||
host: "localhost"
|
host: "localhost"
|
||||||
port: 6379
|
port: 6379
|
||||||
password: ""
|
password: ""
|
||||||
db: 0
|
db: 0
|
||||||
|
|
||||||
memcache:
|
memcache:
|
||||||
servers:
|
servers:
|
||||||
- "localhost:11211"
|
- "localhost:11211"
|
||||||
@@ -33,12 +42,12 @@ cache:
|
|||||||
|
|
||||||
logger:
|
logger:
|
||||||
dev: false
|
dev: false
|
||||||
path: "" # Empty for stdout, or specify file path
|
path: ""
|
||||||
|
|
||||||
middleware:
|
middleware:
|
||||||
rate_limit_rps: 100.0
|
rate_limit_rps: 100.0
|
||||||
rate_limit_burst: 200
|
rate_limit_burst: 200
|
||||||
max_request_size: 10485760 # 10MB in bytes
|
max_request_size: 10485760
|
||||||
|
|
||||||
cors:
|
cors:
|
||||||
allowed_origins:
|
allowed_origins:
|
||||||
@@ -53,5 +62,67 @@ cors:
|
|||||||
- "*"
|
- "*"
|
||||||
max_age: 3600
|
max_age: 3600
|
||||||
|
|
||||||
|
error_tracking:
|
||||||
|
enabled: false
|
||||||
|
provider: "noop"
|
||||||
|
environment: "development"
|
||||||
|
sample_rate: 1.0
|
||||||
|
traces_sample_rate: 0.1
|
||||||
|
|
||||||
|
event_broker:
|
||||||
|
enabled: false
|
||||||
|
provider: "memory"
|
||||||
|
mode: "sync"
|
||||||
|
worker_count: 1
|
||||||
|
buffer_size: 100
|
||||||
|
instance_id: ""
|
||||||
|
redis:
|
||||||
|
stream_name: "events"
|
||||||
|
consumer_group: "app"
|
||||||
|
max_len: 1000
|
||||||
|
host: "localhost"
|
||||||
|
port: 6379
|
||||||
|
password: ""
|
||||||
|
db: 0
|
||||||
|
nats:
|
||||||
|
url: "nats://localhost:4222"
|
||||||
|
stream_name: "events"
|
||||||
|
storage: "file"
|
||||||
|
max_age: 24h
|
||||||
database:
|
database:
|
||||||
url: "host=localhost user=postgres password=postgres dbname=resolvespec_test port=5434 sslmode=disable"
|
table_name: "events"
|
||||||
|
channel: "events"
|
||||||
|
poll_interval: 5s
|
||||||
|
retry_policy:
|
||||||
|
max_retries: 3
|
||||||
|
initial_delay: 1s
|
||||||
|
max_delay: 1m
|
||||||
|
backoff_factor: 2.0
|
||||||
|
|
||||||
|
dbmanager:
|
||||||
|
default_connection: "primary"
|
||||||
|
max_open_conns: 25
|
||||||
|
max_idle_conns: 5
|
||||||
|
conn_max_lifetime: 30m
|
||||||
|
conn_max_idle_time: 5m
|
||||||
|
retry_attempts: 3
|
||||||
|
retry_delay: 1s
|
||||||
|
health_check_interval: 30s
|
||||||
|
enable_auto_reconnect: true
|
||||||
|
connections:
|
||||||
|
primary:
|
||||||
|
name: "primary"
|
||||||
|
type: "pgsql"
|
||||||
|
url: "host=localhost user=postgres password=postgres dbname=resolvespec port=5432 sslmode=disable"
|
||||||
|
default_orm: "gorm"
|
||||||
|
enable_logging: false
|
||||||
|
enable_metrics: false
|
||||||
|
connect_timeout: 10s
|
||||||
|
query_timeout: 30s
|
||||||
|
|
||||||
|
paths:
|
||||||
|
data_dir: "./data"
|
||||||
|
log_dir: "./logs"
|
||||||
|
cache_dir: "./cache"
|
||||||
|
|
||||||
|
extensions: {}
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 352 KiB After Width: | Height: | Size: 95 KiB |
@@ -1,46 +1,46 @@
|
|||||||
module github.com/bitechdev/ResolveSpec
|
module github.com/bitechdev/ResolveSpec
|
||||||
|
|
||||||
go 1.24.0
|
go 1.25.7
|
||||||
|
|
||||||
toolchain go1.24.6
|
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2
|
github.com/DATA-DOG/go-sqlmock v1.5.2
|
||||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf
|
github.com/bradfitz/gomemcache v0.0.0-20260422231931-4d751bb6e37c
|
||||||
github.com/eclipse/paho.mqtt.golang v1.5.1
|
github.com/eclipse/paho.mqtt.golang v1.5.1
|
||||||
github.com/getsentry/sentry-go v0.40.0
|
github.com/getsentry/sentry-go v0.46.2
|
||||||
github.com/glebarez/sqlite v1.11.0
|
github.com/glebarez/sqlite v1.11.0
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/jackc/pgx/v5 v5.8.0
|
github.com/jackc/pgx/v5 v5.9.2
|
||||||
github.com/klauspost/compress v1.18.2
|
github.com/klauspost/compress v1.18.6
|
||||||
github.com/mattn/go-sqlite3 v1.14.33
|
github.com/mark3labs/mcp-go v0.54.0
|
||||||
github.com/microsoft/go-mssqldb v1.9.5
|
github.com/mattn/go-sqlite3 v1.14.44
|
||||||
|
github.com/microsoft/go-mssqldb v1.10.0
|
||||||
github.com/mochi-mqtt/server/v2 v2.7.9
|
github.com/mochi-mqtt/server/v2 v2.7.9
|
||||||
github.com/nats-io/nats.go v1.48.0
|
github.com/nats-io/nats.go v1.52.0
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/redis/go-redis/v9 v9.17.2
|
github.com/redis/go-redis/v9 v9.19.0
|
||||||
github.com/spf13/viper v1.21.0
|
github.com/spf13/viper v1.21.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/testcontainers/testcontainers-go v0.40.0
|
github.com/testcontainers/testcontainers-go v0.40.0
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.19.0
|
||||||
github.com/tidwall/sjson v1.2.5
|
github.com/tidwall/sjson v1.2.5
|
||||||
github.com/uptrace/bun v1.2.16
|
github.com/uptrace/bun v1.2.18
|
||||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16
|
||||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16
|
||||||
github.com/uptrace/bunrouter v1.0.23
|
github.com/uptrace/bunrouter v1.0.23
|
||||||
go.mongodb.org/mongo-driver v1.17.6
|
go.mongodb.org/mongo-driver v1.17.9
|
||||||
go.opentelemetry.io/otel v1.38.0
|
go.opentelemetry.io/otel v1.43.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0
|
||||||
go.opentelemetry.io/otel/sdk v1.38.0
|
go.opentelemetry.io/otel/sdk v1.43.0
|
||||||
go.opentelemetry.io/otel/trace v1.38.0
|
go.opentelemetry.io/otel/trace v1.43.0
|
||||||
go.uber.org/zap v1.27.1
|
go.uber.org/zap v1.28.0
|
||||||
golang.org/x/crypto v0.46.0
|
golang.org/x/crypto v0.51.0
|
||||||
golang.org/x/time v0.14.0
|
golang.org/x/oauth2 v0.36.0
|
||||||
|
golang.org/x/time v0.15.0
|
||||||
gorm.io/driver/postgres v1.6.0
|
gorm.io/driver/postgres v1.6.0
|
||||||
gorm.io/driver/sqlite v1.6.0
|
gorm.io/driver/sqlite v1.6.0
|
||||||
gorm.io/driver/sqlserver v1.6.3
|
gorm.io/driver/sqlserver v1.6.3
|
||||||
@@ -60,7 +60,7 @@ require (
|
|||||||
github.com/containerd/log v0.1.0 // indirect
|
github.com/containerd/log v0.1.0 // indirect
|
||||||
github.com/containerd/platforms v0.2.1 // indirect
|
github.com/containerd/platforms v0.2.1 // indirect
|
||||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
github.com/distribution/reference v0.6.0 // indirect
|
github.com/distribution/reference v0.6.0 // indirect
|
||||||
github.com/docker/docker v28.5.1+incompatible // indirect
|
github.com/docker/docker v28.5.1+incompatible // indirect
|
||||||
@@ -69,16 +69,17 @@ require (
|
|||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||||
github.com/ebitengine/purego v0.8.4 // indirect
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
github.com/fsnotify/fsnotify v1.10.1 // indirect
|
||||||
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
github.com/glebarez/go-sqlite v1.22.0 // indirect
|
||||||
github.com/go-logr/logr v1.4.3 // indirect
|
github.com/go-logr/logr v1.4.3 // indirect
|
||||||
github.com/go-logr/stdr v1.2.2 // indirect
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-ole/go-ole v1.2.6 // indirect
|
github.com/go-ole/go-ole v1.2.6 // indirect
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
|
||||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||||
github.com/golang/snappy v1.0.0 // indirect
|
github.com/golang/snappy v1.0.0 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 // indirect
|
github.com/google/jsonschema-go v0.4.3 // indirect
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 // indirect
|
||||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||||
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
github.com/jackc/puddle/v2 v2.2.2 // indirect
|
||||||
@@ -86,7 +87,7 @@ require (
|
|||||||
github.com/jinzhu/now v1.1.5 // indirect
|
github.com/jinzhu/now v1.1.5 // indirect
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
|
||||||
github.com/magiconair/properties v1.8.10 // indirect
|
github.com/magiconair/properties v1.8.10 // indirect
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
github.com/mattn/go-isatty v0.0.22 // indirect
|
||||||
github.com/moby/docker-image-spec v1.3.1 // indirect
|
github.com/moby/docker-image-spec v1.3.1 // indirect
|
||||||
github.com/moby/go-archive v0.1.0 // indirect
|
github.com/moby/go-archive v0.1.0 // indirect
|
||||||
github.com/moby/patternmatcher v0.6.0 // indirect
|
github.com/moby/patternmatcher v0.6.0 // indirect
|
||||||
@@ -94,25 +95,26 @@ require (
|
|||||||
github.com/moby/sys/user v0.4.0 // indirect
|
github.com/moby/sys/user v0.4.0 // indirect
|
||||||
github.com/moby/sys/userns v0.1.0 // indirect
|
github.com/moby/sys/userns v0.1.0 // indirect
|
||||||
github.com/moby/term v0.5.0 // indirect
|
github.com/moby/term v0.5.0 // indirect
|
||||||
github.com/montanaflynn/stats v0.7.1 // indirect
|
github.com/montanaflynn/stats v0.9.0 // indirect
|
||||||
github.com/morikuni/aec v1.0.0 // indirect
|
github.com/morikuni/aec v1.0.0 // indirect
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||||
github.com/nats-io/nkeys v0.4.11 // indirect
|
github.com/nats-io/nkeys v0.4.15 // indirect
|
||||||
github.com/nats-io/nuid v1.0.1 // indirect
|
github.com/nats-io/nuid v1.0.1 // indirect
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
github.com/ncruces/go-strftime v1.0.0 // indirect
|
||||||
github.com/opencontainers/go-digest v1.0.0 // indirect
|
github.com/opencontainers/go-digest v1.0.0 // indirect
|
||||||
github.com/opencontainers/image-spec v1.1.1 // indirect
|
github.com/opencontainers/image-spec v1.1.1 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
github.com/pelletier/go-toml/v2 v2.3.1 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
|
||||||
github.com/prometheus/client_model v0.6.2 // indirect
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
github.com/prometheus/common v0.67.4 // indirect
|
github.com/prometheus/common v0.67.5 // indirect
|
||||||
github.com/prometheus/procfs v0.19.2 // indirect
|
github.com/prometheus/procfs v0.20.1 // indirect
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||||
github.com/rs/xid v1.4.0 // indirect
|
github.com/rs/xid v1.6.0 // indirect
|
||||||
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
github.com/sagikazarmark/locafero v0.12.0 // indirect
|
||||||
|
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
github.com/shopspring/decimal v1.4.0 // indirect
|
github.com/shopspring/decimal v1.4.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
@@ -131,31 +133,32 @@ require (
|
|||||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||||
github.com/xdg-go/scram v1.2.0 // indirect
|
github.com/xdg-go/scram v1.2.0 // indirect
|
||||||
github.com/xdg-go/stringprep v1.0.4 // indirect
|
github.com/xdg-go/stringprep v1.0.4 // indirect
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect
|
||||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
go.opentelemetry.io/otel/metric v1.43.0 // indirect
|
||||||
go.opentelemetry.io/proto/otlp v1.7.1 // indirect
|
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
|
||||||
|
go.uber.org/atomic v1.11.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
go.yaml.in/yaml/v2 v2.4.4 // indirect
|
||||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect
|
golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a // indirect
|
||||||
golang.org/x/mod v0.31.0 // indirect
|
golang.org/x/mod v0.36.0 // indirect
|
||||||
golang.org/x/net v0.48.0 // indirect
|
golang.org/x/net v0.54.0 // indirect
|
||||||
golang.org/x/oauth2 v0.34.0 // indirect
|
golang.org/x/sync v0.20.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sys v0.44.0 // indirect
|
||||||
golang.org/x/sys v0.39.0 // indirect
|
golang.org/x/text v0.37.0 // indirect
|
||||||
golang.org/x/text v0.32.0 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20260519071638-aa98bba5eb94 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260519071638-aa98bba5eb94 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 // indirect
|
google.golang.org/grpc v1.81.1 // indirect
|
||||||
google.golang.org/grpc v1.75.0 // indirect
|
|
||||||
google.golang.org/protobuf v1.36.11 // indirect
|
google.golang.org/protobuf v1.36.11 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
modernc.org/libc v1.67.4 // indirect
|
modernc.org/libc v1.72.3 // indirect
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
modernc.org/mathutil v1.7.1 // indirect
|
||||||
modernc.org/memory v1.11.0 // indirect
|
modernc.org/memory v1.11.0 // indirect
|
||||||
modernc.org/sqlite v1.42.2 // indirect
|
modernc.org/sqlite v1.50.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
replace github.com/uptrace/bun => github.com/warkanum/bun v1.2.17
|
||||||
|
|||||||
@@ -7,27 +7,33 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.7.1/go.mod h1:bjGvMhVMb+EEm3VRNQ
|
|||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0 h1:Gt0j3wceWMwPmiazCa8MzMA0MfhmPIz0Qp0FJ6qcM0U=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.18.0/go.mod h1:Ot/6aikWnKWi4l9QB7qVSwa8iMphQNqkWALMoNT3rzM=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.1 h1:jHb/wfvRikGdxMXYV3QG/SzUOPYN9KEUUuC0Yd0/vC0=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.3.1/go.mod h1:uE9zaUfEQT/nbQjVi2IblCG9iaLtZsuYZ8ne+PuQ02M=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.6.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1 h1:B+blDbyVIG3WaikNxPnhPiJ1MThR03b3vKGtER95TP4=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.10.1/go.mod h1:JdM5psgjfBf5fo2uWOZhflPWyDBZ/O/CNAH9CtsuZE4=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.8.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1 h1:FPKJS1T+clwv+OLGt13a8UjqeRuh0O4SJ3lUriThc+4=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.1/go.mod h1:j2chePtV91HrC22tGoRX3sGY42uF13WzmmV80/OdVAA=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/internal v1.12.0 h1:fhqpLE3UEXi9lPaBRpQ6XuRW0nU7hgg4zlmZZa+a9q4=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.0.1/go.mod h1:GpPjLhVR9dnUoJMyHWSPy71xY9/lcmpzIPZXmF0FCVY=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1 h1:Wgf5rZba3YZqeTNJPtvqZoBu1sBN/L4sry+u2U3Y75w=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.3.1/go.mod h1:xxCBG/f/4Vbmh2XQJBsOmNdxWUY5j/s27jujKPbQf14=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/azkeys v1.4.0 h1:E4MgwLBGeVB5f2MdcIVD3ELVAWpr+WD6MUe1i+tM/PA=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.0.0/go.mod h1:bTSOgj05NGRuHHhQwAdPnYr9TOdNmKlZTgGLL6nyAdI=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1 h1:bFWuoEKg+gImo7pvkiQEFAc8ocibADgXeiLAxWhWmkI=
|
||||||
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.1.1/go.mod h1:Vih/3yc6yac2JzU4hzpaDupBJP0Flaia9rXXrU8xyww=
|
||||||
|
github.com/Azure/azure-sdk-for-go/sdk/security/keyvault/internal v1.2.0 h1:nCYfgcSyHZXJI8J0IWE5MsCGlb2xp9fJiXyxWgmOFg4=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
|
||||||
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2 h1:oygO0locgZJe7PpYPXT5A29ZkwJaPqcva7BVeemZOZs=
|
||||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.4.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI=
|
||||||
|
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
@@ -36,6 +42,8 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf h1:TqhNAT4zKbTdLa62d2HDBFdvgSbIGB3eJE8HqhgiL9I=
|
||||||
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
|
github.com/bradfitz/gomemcache v0.0.0-20250403215159-8d39553ac7cf/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
|
||||||
|
github.com/bradfitz/gomemcache v0.0.0-20260422231931-4d751bb6e37c h1:6Gpm9YYUEQx2T9zMsYolQhr6sjwwGtFitSA0pQsa7a8=
|
||||||
|
github.com/bradfitz/gomemcache v0.0.0-20260422231931-4d751bb6e37c/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
@@ -62,6 +70,8 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr
|
|||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
|
||||||
|
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||||
@@ -86,8 +96,12 @@ github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHk
|
|||||||
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
|
github.com/fsnotify/fsnotify v1.10.1 h1:b0/UzAf9yR5rhf3RPm9gf3ehBPpf0oZKIjtpKrx59Ho=
|
||||||
|
github.com/fsnotify/fsnotify v1.10.1/go.mod h1:TLheqan6HD6GBK6PrDWyDPBaEV8LspOxvPSjC+bVfgo=
|
||||||
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
github.com/getsentry/sentry-go v0.40.0 h1:VTJMN9zbTvqDqPwheRVLcp0qcUcM+8eFivvGocAaSbo=
|
||||||
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
github.com/getsentry/sentry-go v0.40.0/go.mod h1:eRXCoh3uvmjQLY6qu63BjUZnaBu5L5WhMV1RwYO8W5s=
|
||||||
|
github.com/getsentry/sentry-go v0.46.2 h1:1jhYwrKGa3sIpo/y5iDNXS5wDoT7I1KNzMHrnK6ojns=
|
||||||
|
github.com/getsentry/sentry-go v0.46.2/go.mod h1:evVbw2qotNUdYG8KxXbAdjOQWWvWIwKxpjdZZIvcIPw=
|
||||||
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ=
|
||||||
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc=
|
||||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||||
@@ -103,11 +117,14 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
|||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
|
||||||
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro=
|
||||||
|
github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
|
||||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY=
|
||||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA=
|
||||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
|
||||||
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A=
|
||||||
@@ -120,6 +137,10 @@ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/jsonschema-go v0.4.2 h1:tmrUohrwoLZZS/P3x7ex0WAVknEkBZM46iALbcqoRA8=
|
||||||
|
github.com/google/jsonschema-go v0.4.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
|
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
|
||||||
|
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
||||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
@@ -133,6 +154,8 @@ github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aN
|
|||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnVTyacbefKhmbLhIhU=
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0 h1:5VipnvEpbqr2gA2VbM+nYVbkIF28c5ZQfqCBQ5g2xfk=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0/go.mod h1:Hyl3n6Twe1hvtd9XUXDec4pTvgMSEixRuQKPTMH2bNs=
|
||||||
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
@@ -143,6 +166,8 @@ github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7Ulw
|
|||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||||
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo=
|
||||||
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2 h1:3ZhOzMWnR4yJ+RW1XImIPsD1aNSz4T4fyP7zlQb56hw=
|
||||||
|
github.com/jackc/pgx/v5 v5.9.2/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4=
|
||||||
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
|
||||||
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
|
||||||
@@ -160,6 +185,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
|
|||||||
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||||
|
github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao=
|
||||||
|
github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
@@ -173,13 +200,23 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ
|
|||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
|
||||||
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0 h1:8KRibF4wcKejbLsHxCA/QBVUr5fQ9nwz/n8lGqmaALo=
|
||||||
|
github.com/mark3labs/mcp-go v0.46.0/go.mod h1:JKTC7R2LLVagkEWK7Kwu7DbmA6iIvnNAod6yrHiQMag=
|
||||||
|
github.com/mark3labs/mcp-go v0.54.0 h1:PZhQvd+5xrT43cUoiaKn/hDcvLUhcLc1twSEKYPTcTA=
|
||||||
|
github.com/mark3labs/mcp-go v0.54.0/go.mod h1:+8WclSK1ZUweCP3hvktSji8n8ABG/95QaEkeVE/Uwas=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/mattn/go-isatty v0.0.22 h1:j8l17JJ9i6VGPUFUYoTUKPSgKe/83EYU2zBC7YNKMw4=
|
||||||
|
github.com/mattn/go-isatty v0.0.22/go.mod h1:ZXfXG4SQHsB/w3ZeOYbR0PrPwLy+n6xiMrJlRFqopa4=
|
||||||
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0=
|
||||||
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.44 h1:3VSe+xafpbzsLbdr2AWlAZk9yRHiBhTBakioXaCKTF8=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.44/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ=
|
||||||
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
github.com/microsoft/go-mssqldb v1.8.2/go.mod h1:vp38dT33FGfVotRiTmDo3bFyaHq+p3LektQrjTULowo=
|
||||||
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
|
github.com/microsoft/go-mssqldb v1.9.5 h1:orwya0X/5bsL1o+KasupTkk2eNTNFkTQG0BEe/HxCn0=
|
||||||
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
|
github.com/microsoft/go-mssqldb v1.9.5/go.mod h1:VCP2a0KEZZtGLRHd1PsLavLFYy/3xX2yJUPycv3Sr2Q=
|
||||||
|
github.com/microsoft/go-mssqldb v1.10.0 h1:pHEt+Qz6YFPWqREq10mqSE524QQo+/QremwTCQht7TY=
|
||||||
|
github.com/microsoft/go-mssqldb v1.10.0/go.mod h1:mnG7lGa9iYJbzJqGCXyuQCegStKMr3kogDLD6+bmggg=
|
||||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||||
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
github.com/moby/go-archive v0.1.0 h1:Kk/5rdW/g+H8NHdJW2gsXyZ7UnzvJNOy6VKJqueWdcQ=
|
||||||
@@ -202,14 +239,20 @@ github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3P
|
|||||||
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||||
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE=
|
||||||
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||||
|
github.com/montanaflynn/stats v0.9.0 h1:tsBJ0RXwph9BmAuFoCmqGv6e8xa0MENQ8m0ptKq29mQ=
|
||||||
|
github.com/montanaflynn/stats v0.9.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
|
||||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
|
||||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
|
||||||
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
|
github.com/nats-io/nats.go v1.48.0 h1:pSFyXApG+yWU/TgbKCjmm5K4wrHu86231/w84qRVR+U=
|
||||||
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
|
github.com/nats-io/nats.go v1.48.0/go.mod h1:iRWIPokVIFbVijxuMQq4y9ttaBTMe0SFdlZfMDd+33g=
|
||||||
|
github.com/nats-io/nats.go v1.52.0 h1:n3avV4VBsCgsdwh71TppsTwtv+QdPs7ntSKM8qJLGsc=
|
||||||
|
github.com/nats-io/nats.go v1.52.0/go.mod h1:26HypzazeOkyO3/mqd1zZd53STJN0EjCYF9Uy2ZOBno=
|
||||||
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
|
github.com/nats-io/nkeys v0.4.11 h1:q44qGV008kYd9W1b1nEBkNzvnWxtRSQ7A8BoqRrcfa0=
|
||||||
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
|
github.com/nats-io/nkeys v0.4.11/go.mod h1:szDimtgmfOi9n25JpfIdGw12tZFYXqhGxjhVxsatHVE=
|
||||||
|
github.com/nats-io/nkeys v0.4.15 h1:JACV5jRVO9V856KOapQ7x+EY8Jo3qw1vJt/9Jpwzkk4=
|
||||||
|
github.com/nats-io/nkeys v0.4.15/go.mod h1:CpMchTXC9fxA5zrMo4KpySxNjiDVvr8ANOSZdiNfUrs=
|
||||||
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw=
|
||||||
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c=
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
||||||
@@ -220,6 +263,8 @@ github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJw
|
|||||||
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.3.1/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4=
|
||||||
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
github.com/pingcap/errors v0.11.4/go.mod h1:Oi8TUi2kEtXXLMJk9l1cGmz20kV3TaQ0usTwv5KuLY8=
|
||||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
|
||||||
@@ -238,22 +283,33 @@ github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNw
|
|||||||
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE=
|
||||||
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+Lvsc=
|
||||||
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI=
|
||||||
|
github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4=
|
||||||
|
github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw=
|
||||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||||
|
github.com/prometheus/procfs v0.20.1 h1:XwbrGOIplXW/AU3YhIhLODXMJYyC1isLFfYCsTEycfc=
|
||||||
|
github.com/prometheus/procfs v0.20.1/go.mod h1:o9EMBZGRyvDrSPH1RqdxhojkuXstoe4UlK79eF5TGGo=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1 h1:GJYJZwO6IdxN/IKbneznS6yPkVC+c3zyY/j19c++5Fg=
|
||||||
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
github.com/puzpuzpuz/xsync/v3 v3.5.1/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
|
||||||
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
|
||||||
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||||
|
github.com/redis/go-redis/v9 v9.19.0 h1:XPVaaPSnG6RhYf7p+rmSa9zZfeVAnWsH5h3lxthOm/k=
|
||||||
|
github.com/redis/go-redis/v9 v9.19.0/go.mod h1:v/M13XI1PVCDcm01VtPFOADfZtHf8YW3baQf57KlIkA=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
||||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
|
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
|
||||||
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
github.com/rs/xid v1.4.0 h1:qd7wPTDkN6KQx2VmMBLrpHkiyQwgFXRnkOLacUiaSNY=
|
||||||
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||||
|
github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
|
||||||
|
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
|
||||||
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
|
||||||
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
|
||||||
|
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 h1:KRzFb2m7YtdldCEkzs6KqmJw4nqEVZGK7IN2kJkjTuQ=
|
||||||
|
github.com/santhosh-tekuri/jsonschema/v6 v6.0.2/go.mod h1:JXeL+ps8p7/KNMjDQk3TCwPpBy0wYklyWTfbkIzdIFU=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||||
@@ -290,6 +346,8 @@ github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3
|
|||||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||||
|
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM=
|
||||||
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
@@ -306,10 +364,22 @@ github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYm
|
|||||||
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
|
||||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16 h1:rKv0cKPNBviXadB/+2Y/UedA/c1JnwGzUWZkdN5FdSQ=
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16 h1:rKv0cKPNBviXadB/+2Y/UedA/c1JnwGzUWZkdN5FdSQ=
|
||||||
github.com/uptrace/bun/dialect/mssqldialect v1.2.16/go.mod h1:J5U7tGKWDsx2Q7MwDZF2417jCdpD6yD/ZMFJcCR80bk=
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.16/go.mod h1:J5U7tGKWDsx2Q7MwDZF2417jCdpD6yD/ZMFJcCR80bk=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.17 h1:xEUH4WamuY9rXT9d8wHVZanhmLJCrc4s4v7frDH/PMc=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.17/go.mod h1:i1NRx/5cz1nivwtV7FEb/gP3CIbRTj4AQC9/Q0lNVno=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.18 h1:nYzHoyJKJlIyl5i95Exi8ZTK8ooKWG+o3z3f404d/yQ=
|
||||||
|
github.com/uptrace/bun/dialect/mssqldialect v1.2.18/go.mod h1:Su45Je7z66sfeZ3d1ZsnOQEK8xfzGgaMzBvtoE8yFhk=
|
||||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16 h1:KFNZ0LxAyczKNfK/IJWMyaleO6eI9/Z5tUv3DE1NVL4=
|
||||||
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
|
github.com/uptrace/bun/dialect/pgdialect v1.2.16/go.mod h1:IJdMeV4sLfh0LDUZl7TIxLI0LipF1vwTK3hBC7p5qLo=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.17 h1:DFmhOollvbYHvooxoS8ZIbiGC0wXIzstKeFUmWs+TP4=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.17/go.mod h1:ej8ZDsvLETvyELlRDfUtIoA57sWnATv1GhOEVsuVG/k=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.18 h1:IZ6nM2+OYrL8lkEAy7UkSEZvoa3vluTAUlZfPtlRB2k=
|
||||||
|
github.com/uptrace/bun/dialect/pgdialect v1.2.18/go.mod h1:Tqdf4QP1okrGYpXfodXvCOK6Ob1OOTwSaoAzCgBB3IU=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16 h1:6wVAiYLj1pMibRthGwy4wDLa3D5AQo32Y8rvwPd8CQ0=
|
||||||
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.16/go.mod h1:Z7+5qK8CGZkDQiPMu+LSdVuDuR1I5jcwtkB1Pi3F82E=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.17 h1:ZipEoNr+wQJQleGy2poKSSoaQDavzc+nXTDp3ZzkA0E=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.17/go.mod h1:phXmrxxeYqUhMU09FgazbfNxq9LlArdqjZqHc1ILy9U=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.18 h1:Z33SY/U++XK9uGWqS4h8OZVxfCXguIG+sU9cYq2PGFQ=
|
||||||
|
github.com/uptrace/bun/dialect/sqlitedialect v1.2.18/go.mod h1:1MVOS/Ncy4FZbkJcgUFH6OqYoQinYNjkEwsmNQEXz2A=
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16 h1:M6Dh5kkDWFbUWBrOsIE1g1zdZ5JbSytTD4piFRBOUAI=
|
||||||
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
|
github.com/uptrace/bun/driver/sqliteshim v1.2.16/go.mod h1:iKdJ06P3XS+pwKcONjSIK07bbhksH3lWsw3mpfr0+bY=
|
||||||
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
github.com/uptrace/bunrouter v1.0.23 h1:Bi7NKw3uCQkcA/GUCtDNPq5LE5UdR9pe+UyWbjHB/wU=
|
||||||
@@ -326,6 +396,8 @@ github.com/xdg-go/scram v1.2.0 h1:bYKF2AEwG5rqd1BumT4gAnvwU/M9nBp2pTSxeZw7Wvs=
|
|||||||
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
github.com/xdg-go/scram v1.2.0/go.mod h1:3dlrS0iBaWKYVt2ZfA4cj48umJZ+cAEbR6/SjLA88I8=
|
||||||
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
|
||||||
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
|
||||||
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM=
|
||||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
@@ -333,36 +405,61 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo
|
|||||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
|
||||||
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
|
go.mongodb.org/mongo-driver v1.17.6 h1:87JUG1wZfWsr6rIz3ZmpH90rL5tea7O3IHuSwHUpsss=
|
||||||
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
go.mongodb.org/mongo-driver v1.17.6/go.mod h1:Hy04i7O2kC4RS06ZrhPRqj/u4DTYkFDAAccj+rVKqgQ=
|
||||||
|
go.mongodb.org/mongo-driver v1.17.9 h1:IexDdCuuNJ3BHrELgBlyaH9p60JXAvdzWR128q+U5tU=
|
||||||
|
go.mongodb.org/mongo-driver v1.17.9/go.mod h1:LlOhpH5NUEfhxcAwG0UEkMqwYcc4JU18gtCdGudk/tQ=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk=
|
||||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw=
|
||||||
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
|
||||||
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
|
||||||
|
go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I=
|
||||||
|
go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0 h1:GqRJVj7UmLjCVyVJ3ZFLdPRmhDUp2zFmQe3RHIOsw24=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.38.0/go.mod h1:ri3aaHSmCTVYu2AWv44YMauwAQc0aqI9gHKIcSbI1pU=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0 h1:lwI4Dc5leUqENgGuQImwLo4WnuXFPetmPpkLi2IrX54=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.38.0/go.mod h1:Kz/oCE7z5wuyhPxsXDuaPteSWqjSBD5YaSdbxZYGbGk=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0 h1:IeMeyr1aBvBiPVYihXIaeIZba6b8E1bYp7lbdxK8CQg=
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.19.0/go.mod h1:oVdCUtjq9MK9BlS7TtucsQwUcXcymNiEDjgDD2jMtZU=
|
||||||
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
|
||||||
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
|
||||||
|
go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM=
|
||||||
|
go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY=
|
||||||
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E=
|
||||||
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
go.opentelemetry.io/otel/sdk v1.38.0/go.mod h1:ghmNdGlVemJI3+ZB5iDEuk4bWA3GkTpW+DOoZMYBVVg=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM=
|
||||||
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
go.opentelemetry.io/otel/sdk/metric v1.38.0/go.mod h1:dg9PBnW9XdQ1Hd6ZnRz689CbtrUp0wMMs9iPcgT9EZA=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw=
|
||||||
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
|
||||||
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
|
||||||
|
go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A=
|
||||||
|
go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0=
|
||||||
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
|
go.opentelemetry.io/proto/otlp v1.7.1 h1:gTOMpGDb0WTBOP8JaO72iL3auEZhVmAQg4ipjOVAtj4=
|
||||||
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
go.opentelemetry.io/proto/otlp v1.7.1/go.mod h1:b2rVh6rfI/s2pHWNlB7ILJcRALpcNDzKhACevjI+ZnE=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk=
|
||||||
|
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
|
||||||
|
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
|
||||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||||
|
go.uber.org/zap v1.28.0 h1:IZzaP1Fv73/T/pBMLk4VutPl36uNC+OSUh3JLG3FIjo=
|
||||||
|
go.uber.org/zap v1.28.0/go.mod h1:rDLpOi171uODNm/mxFcuYWxDsqWSAVkFdX4XojSKg/Q=
|
||||||
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0=
|
||||||
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.4 h1:tuyd0P+2Ont/d6e2rl3be67goVK4R6deVxCUX5vyPaQ=
|
||||||
|
go.yaml.in/yaml/v2 v2.4.4/go.mod h1:gMZqIpDtDqOfM0uNfy0SkpRhvUryYH0Z6wdMYcacYXQ=
|
||||||
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
|
||||||
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
@@ -379,8 +476,12 @@ golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v
|
|||||||
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
|
||||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||||
|
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
|
||||||
|
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
|
||||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||||
|
golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a h1:+3jdDGGB8NGb1Zktc737jlt3/A5f6UlwSzmvqUuufxw=
|
||||||
|
golang.org/x/exp v0.0.0-20260508232706-74f9aab9d74a/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
@@ -389,6 +490,8 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
|||||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI=
|
||||||
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg=
|
||||||
|
golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4=
|
||||||
|
golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ=
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
@@ -408,8 +511,12 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
|||||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||||
|
golang.org/x/net v0.54.0 h1:2zJIZAxAHV/OHCDTCOHAYehQzLfSXuf/5SoL/Dv6w/w=
|
||||||
|
golang.org/x/net v0.54.0/go.mod h1:Sj4oj8jK6XmHpBZU/zWHw3BV3abl4Kvi+Ut7cQcY+cQ=
|
||||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||||
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -419,6 +526,8 @@ golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
|||||||
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||||
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4=
|
||||||
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||||
|
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
|
||||||
|
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
@@ -444,6 +553,8 @@ golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|||||||
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
|
||||||
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/sys v0.44.0 h1:ildZl3J4uzeKP07r2F++Op7E9B29JRUy+a27EibtBTQ=
|
||||||
|
golang.org/x/sys v0.44.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
|
||||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
@@ -461,6 +572,7 @@ golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
|||||||
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
|
||||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||||
|
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
@@ -477,8 +589,12 @@ golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI=
|
|||||||
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
|
||||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||||
|
golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc=
|
||||||
|
golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38=
|
||||||
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
|
||||||
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
|
||||||
|
golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U=
|
||||||
|
golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
@@ -487,16 +603,24 @@ golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58
|
|||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA=
|
||||||
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc=
|
||||||
|
golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
|
||||||
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
|
||||||
|
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5 h1:BIRfGDEjiHRrk0QKZe3Xv2ieMhtgRGeLcZQ0mIVn4EY=
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250825161204-c5933d9347a5/go.mod h1:j3QtIyytwqGr1JUDtYXwtMXWPKsEa5LtzIFN1Wn5WvE=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260519071638-aa98bba5eb94 h1:DddG61lE5LkX6144z22i0gma9BMBs5aZ9B8lZLobxyw=
|
||||||
|
google.golang.org/genproto/googleapis/api v0.0.0-20260519071638-aa98bba5eb94/go.mod h1:1dCETSCY2YKZNXQE3h4fun3TYwF5p8jejRKZgfWAgAY=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5 h1:eaY8u2EuxbRv7c3NiGK0/NedzVsCcV6hDuU5qPX5EGE=
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250825161204-c5933d9347a5/go.mod h1:M4/wBTSeyLxupu3W3tJtOgB14jILAS/XWPSSa3TAlJc=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260519071638-aa98bba5eb94 h1:eZCjr/aAF8c5ccm5pb6T4EXgIei5MlAAPWPJk+5ArfY=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20260519071638-aa98bba5eb94/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8=
|
||||||
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
google.golang.org/grpc v1.75.0 h1:+TW+dqTd2Biwe6KKfhE5JpiYIBWq865PhKGSXiivqt4=
|
||||||
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
google.golang.org/grpc v1.75.0/go.mod h1:JtPAzKiq4v1xcAB2hydNlWI2RnF85XXcV0mhKXr2ecQ=
|
||||||
|
google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ=
|
||||||
|
google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I=
|
||||||
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
|
||||||
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
@@ -522,28 +646,37 @@ gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
|||||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
||||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
||||||
|
modernc.org/cc/v4 v4.28.2 h1:3tQ0lf2ADtoby2EtSP+J7IE2SHwEJdP8ioR59wx7XpY=
|
||||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
||||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
||||||
|
modernc.org/ccgo/v4 v4.34.0 h1:yRLPFZieg532OT4rp4JFNIVcquwalMX26G95WQDqwCQ=
|
||||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
||||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
||||||
|
modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM=
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
||||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
||||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
||||||
|
modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo=
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
||||||
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg=
|
||||||
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA=
|
||||||
|
modernc.org/libc v1.72.3 h1:ZnDF4tXn4NBXFutMMQC4vtbTFSXhhKzR73fv0beZEAU=
|
||||||
|
modernc.org/libc v1.72.3/go.mod h1:dn0dZNnnn1clLyvRxLxYExxiKRZIRENOfqQ8XEeg4Qs=
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
||||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
||||||
|
modernc.org/opt v0.2.0 h1:tGyef5ApycA7FSEOMraay9SaTk5zmbx7Tu+cJs4QKZg=
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
||||||
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
|
modernc.org/sqlite v1.42.2 h1:7hkZUNJvJFN2PgfUdjni9Kbvd4ef4mNLOu0B9FGxM74=
|
||||||
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
|
modernc.org/sqlite v1.42.2/go.mod h1:+VkC6v3pLOAE0A0uVucQEcbVW0I5nHCeDaBf+DpsQT8=
|
||||||
|
modernc.org/sqlite v1.50.1 h1:l+cQvn0sd0zJJtfygGHuQJ5AjlrwXmWPw4KP3ZMwr9w=
|
||||||
|
modernc.org/sqlite v1.50.1/go.mod h1:tcNzv5p84E0skkmJn038y+hWJbLQXQqEnQfeh5r2JLM=
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
||||||
|
|||||||
-362
@@ -1,362 +0,0 @@
|
|||||||
openapi: 3.0.0
|
|
||||||
info:
|
|
||||||
title: ResolveSpec API
|
|
||||||
version: '1.0'
|
|
||||||
description: A flexible REST API with GraphQL-like capabilities
|
|
||||||
|
|
||||||
servers:
|
|
||||||
- url: 'http://api.example.com/v1'
|
|
||||||
|
|
||||||
paths:
|
|
||||||
'/{schema}/{entity}':
|
|
||||||
parameters:
|
|
||||||
- name: schema
|
|
||||||
in: path
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
- name: entity
|
|
||||||
in: path
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
get:
|
|
||||||
summary: Get table metadata
|
|
||||||
description: Retrieve table metadata including columns, types, and relationships
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: Successful operation
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
allOf:
|
|
||||||
- $ref: '#/components/schemas/Response'
|
|
||||||
- type: object
|
|
||||||
properties:
|
|
||||||
data:
|
|
||||||
$ref: '#/components/schemas/TableMetadata'
|
|
||||||
'400':
|
|
||||||
$ref: '#/components/responses/BadRequest'
|
|
||||||
'404':
|
|
||||||
$ref: '#/components/responses/NotFound'
|
|
||||||
'500':
|
|
||||||
$ref: '#/components/responses/ServerError'
|
|
||||||
post:
|
|
||||||
summary: Perform operations on entities
|
|
||||||
requestBody:
|
|
||||||
required: true
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Request'
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: Successful operation
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Response'
|
|
||||||
'400':
|
|
||||||
$ref: '#/components/responses/BadRequest'
|
|
||||||
'404':
|
|
||||||
$ref: '#/components/responses/NotFound'
|
|
||||||
'500':
|
|
||||||
$ref: '#/components/responses/ServerError'
|
|
||||||
|
|
||||||
'/{schema}/{entity}/{id}':
|
|
||||||
parameters:
|
|
||||||
- name: schema
|
|
||||||
in: path
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
- name: entity
|
|
||||||
in: path
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
- name: id
|
|
||||||
in: path
|
|
||||||
required: true
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
post:
|
|
||||||
summary: Perform operations on a specific entity
|
|
||||||
requestBody:
|
|
||||||
required: true
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Request'
|
|
||||||
responses:
|
|
||||||
'200':
|
|
||||||
description: Successful operation
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Response'
|
|
||||||
'400':
|
|
||||||
$ref: '#/components/responses/BadRequest'
|
|
||||||
'404':
|
|
||||||
$ref: '#/components/responses/NotFound'
|
|
||||||
'500':
|
|
||||||
$ref: '#/components/responses/ServerError'
|
|
||||||
|
|
||||||
components:
|
|
||||||
schemas:
|
|
||||||
Request:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- operation
|
|
||||||
properties:
|
|
||||||
operation:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- read
|
|
||||||
- create
|
|
||||||
- update
|
|
||||||
- delete
|
|
||||||
id:
|
|
||||||
oneOf:
|
|
||||||
- type: string
|
|
||||||
- type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
description: Optional record identifier(s) when not provided in URL
|
|
||||||
data:
|
|
||||||
oneOf:
|
|
||||||
- type: object
|
|
||||||
- type: array
|
|
||||||
items:
|
|
||||||
type: object
|
|
||||||
description: Data for single or bulk create/update operations
|
|
||||||
options:
|
|
||||||
$ref: '#/components/schemas/Options'
|
|
||||||
|
|
||||||
Options:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
preload:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/PreloadOption'
|
|
||||||
columns:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
filters:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/FilterOption'
|
|
||||||
sort:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/SortOption'
|
|
||||||
limit:
|
|
||||||
type: integer
|
|
||||||
minimum: 0
|
|
||||||
offset:
|
|
||||||
type: integer
|
|
||||||
minimum: 0
|
|
||||||
customOperators:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/CustomOperator'
|
|
||||||
computedColumns:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/ComputedColumn'
|
|
||||||
|
|
||||||
PreloadOption:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
relation:
|
|
||||||
type: string
|
|
||||||
columns:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
filters:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/FilterOption'
|
|
||||||
|
|
||||||
FilterOption:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- column
|
|
||||||
- operator
|
|
||||||
- value
|
|
||||||
properties:
|
|
||||||
column:
|
|
||||||
type: string
|
|
||||||
operator:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- eq
|
|
||||||
- neq
|
|
||||||
- gt
|
|
||||||
- gte
|
|
||||||
- lt
|
|
||||||
- lte
|
|
||||||
- like
|
|
||||||
- ilike
|
|
||||||
- in
|
|
||||||
value:
|
|
||||||
type: object
|
|
||||||
|
|
||||||
SortOption:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- column
|
|
||||||
- direction
|
|
||||||
properties:
|
|
||||||
column:
|
|
||||||
type: string
|
|
||||||
direction:
|
|
||||||
type: string
|
|
||||||
enum:
|
|
||||||
- asc
|
|
||||||
- desc
|
|
||||||
|
|
||||||
CustomOperator:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- name
|
|
||||||
- sql
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
sql:
|
|
||||||
type: string
|
|
||||||
|
|
||||||
ComputedColumn:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- name
|
|
||||||
- expression
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
expression:
|
|
||||||
type: string
|
|
||||||
|
|
||||||
Response:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- success
|
|
||||||
properties:
|
|
||||||
success:
|
|
||||||
type: boolean
|
|
||||||
data:
|
|
||||||
type: object
|
|
||||||
metadata:
|
|
||||||
$ref: '#/components/schemas/Metadata'
|
|
||||||
error:
|
|
||||||
$ref: '#/components/schemas/Error'
|
|
||||||
|
|
||||||
Metadata:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
total:
|
|
||||||
type: integer
|
|
||||||
filtered:
|
|
||||||
type: integer
|
|
||||||
limit:
|
|
||||||
type: integer
|
|
||||||
offset:
|
|
||||||
type: integer
|
|
||||||
|
|
||||||
Error:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
code:
|
|
||||||
type: string
|
|
||||||
message:
|
|
||||||
type: string
|
|
||||||
details:
|
|
||||||
type: object
|
|
||||||
|
|
||||||
TableMetadata:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- schema
|
|
||||||
- table
|
|
||||||
- columns
|
|
||||||
- relations
|
|
||||||
properties:
|
|
||||||
schema:
|
|
||||||
type: string
|
|
||||||
description: Schema name
|
|
||||||
table:
|
|
||||||
type: string
|
|
||||||
description: Table name
|
|
||||||
columns:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
$ref: '#/components/schemas/Column'
|
|
||||||
relations:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: string
|
|
||||||
description: List of relation names
|
|
||||||
|
|
||||||
Column:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- name
|
|
||||||
- type
|
|
||||||
- is_nullable
|
|
||||||
- is_primary
|
|
||||||
- is_unique
|
|
||||||
- has_index
|
|
||||||
properties:
|
|
||||||
name:
|
|
||||||
type: string
|
|
||||||
description: Column name
|
|
||||||
type:
|
|
||||||
type: string
|
|
||||||
description: Data type of the column
|
|
||||||
is_nullable:
|
|
||||||
type: boolean
|
|
||||||
description: Whether the column can contain null values
|
|
||||||
is_primary:
|
|
||||||
type: boolean
|
|
||||||
description: Whether the column is a primary key
|
|
||||||
is_unique:
|
|
||||||
type: boolean
|
|
||||||
description: Whether the column has a unique constraint
|
|
||||||
has_index:
|
|
||||||
type: boolean
|
|
||||||
description: Whether the column is indexed
|
|
||||||
|
|
||||||
responses:
|
|
||||||
BadRequest:
|
|
||||||
description: Bad request
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Response'
|
|
||||||
|
|
||||||
NotFound:
|
|
||||||
description: Resource not found
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Response'
|
|
||||||
|
|
||||||
ServerError:
|
|
||||||
description: Internal server error
|
|
||||||
content:
|
|
||||||
application/json:
|
|
||||||
schema:
|
|
||||||
$ref: '#/components/schemas/Response'
|
|
||||||
|
|
||||||
securitySchemes:
|
|
||||||
bearerAuth:
|
|
||||||
type: http
|
|
||||||
scheme: bearer
|
|
||||||
bearerFormat: JWT
|
|
||||||
|
|
||||||
security:
|
|
||||||
- bearerAuth: []
|
|
||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/uptrace/bun"
|
"github.com/uptrace/bun"
|
||||||
@@ -95,17 +96,56 @@ func debugScanIntoStruct(rows interface{}, dest interface{}) error {
|
|||||||
// This demonstrates how the abstraction works with different ORMs
|
// This demonstrates how the abstraction works with different ORMs
|
||||||
type BunAdapter struct {
|
type BunAdapter struct {
|
||||||
db *bun.DB
|
db *bun.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*bun.DB, error)
|
||||||
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBunAdapter creates a new Bun adapter
|
// NewBunAdapter creates a new Bun adapter
|
||||||
func NewBunAdapter(db *bun.DB) *BunAdapter {
|
func NewBunAdapter(db *bun.DB) *BunAdapter {
|
||||||
return &BunAdapter{db: db}
|
adapter := &BunAdapter{db: db, metricsEnabled: true}
|
||||||
|
// Initialize driver name
|
||||||
|
adapter.driverName = adapter.DriverName()
|
||||||
|
return adapter
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (b *BunAdapter) WithDBFactory(factory func() (*bun.DB, error)) *BunAdapter {
|
||||||
|
b.dbFactory = factory
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||||
|
func (b *BunAdapter) SetMetricsEnabled(enabled bool) *BunAdapter {
|
||||||
|
b.metricsEnabled = enabled
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) getDB() *bun.DB {
|
||||||
|
b.dbMu.RLock()
|
||||||
|
defer b.dbMu.RUnlock()
|
||||||
|
return b.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *BunAdapter) reconnectDB() error {
|
||||||
|
if b.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := b.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
b.dbMu.Lock()
|
||||||
|
b.db = newDB
|
||||||
|
b.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
// This is useful for debugging preload queries that may be failing
|
// This is useful for debugging preload queries that may be failing
|
||||||
func (b *BunAdapter) EnableQueryDebug() {
|
func (b *BunAdapter) EnableQueryDebug() {
|
||||||
b.db.AddQueryHook(&QueryDebugHook{})
|
b.getDB().AddQueryHook(&QueryDebugHook{})
|
||||||
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
logger.Info("Bun query debug mode enabled - all SQL queries will be logged")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,21 +166,23 @@ func (b *BunAdapter) DisableQueryDebug() {
|
|||||||
|
|
||||||
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
func (b *BunAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{
|
return &BunSelectQuery{
|
||||||
query: b.db.NewSelect(),
|
query: b.getDB().NewSelect(),
|
||||||
db: b.db,
|
db: b.db,
|
||||||
|
driverName: b.driverName,
|
||||||
|
metricsEnabled: b.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
func (b *BunAdapter) NewInsert() common.InsertQuery {
|
||||||
return &BunInsertQuery{query: b.db.NewInsert()}
|
return &BunInsertQuery{query: b.getDB().NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
func (b *BunAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &BunUpdateQuery{query: b.db.NewUpdate()}
|
return &BunUpdateQuery{query: b.getDB().NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
func (b *BunAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &BunDeleteQuery{query: b.db.NewDelete()}
|
return &BunDeleteQuery{query: b.getDB().NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
@@ -149,7 +191,17 @@ func (b *BunAdapter) Exec(ctx context.Context, query string, args ...interface{}
|
|||||||
err = logger.HandlePanic("BunAdapter.Exec", r)
|
err = logger.HandlePanic("BunAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result, err := b.db.ExecContext(ctx, query, args...)
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||||
|
var result sql.Result
|
||||||
|
run := func() error { var e error; result, e = b.getDB().ExecContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,16 +211,29 @@ func (b *BunAdapter) Query(ctx context.Context, dest interface{}, query string,
|
|||||||
err = logger.HandlePanic("BunAdapter.Query", r)
|
err = logger.HandlePanic("BunAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.NewRaw(query, args...).Scan(ctx, dest)
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||||
|
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
err = b.getDB().NewRaw(query, args...).Scan(ctx, dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (b *BunAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx, err := b.db.BeginTx(ctx, &sql.TxOptions{})
|
tx, err := b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
tx, err = b.getDB().BeginTx(ctx, &sql.TxOptions{})
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// For Bun, we'll return a special wrapper that holds the transaction
|
return &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}, nil
|
||||||
return &BunTxAdapter{tx: tx, driverName: b.DriverName()}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) CommitTx(ctx context.Context) error {
|
func (b *BunAdapter) CommitTx(ctx context.Context) error {
|
||||||
@@ -189,23 +254,34 @@ func (b *BunAdapter) RunInTransaction(ctx context.Context, fn func(common.Databa
|
|||||||
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
err = logger.HandlePanic("BunAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return b.db.RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
run := func() error {
|
||||||
// Create adapter with transaction
|
return b.getDB().RunInTx(ctx, &sql.TxOptions{}, func(ctx context.Context, tx bun.Tx) error {
|
||||||
adapter := &BunTxAdapter{tx: tx, driverName: b.DriverName()}
|
adapter := &BunTxAdapter{tx: tx, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := b.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
func (b *BunAdapter) GetUnderlyingDB() interface{} {
|
||||||
return b.db
|
return b.getDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunAdapter) DriverName() string {
|
func (b *BunAdapter) DriverName() string {
|
||||||
// Normalize Bun's dialect name to match the project's canonical vocabulary.
|
// Normalize Bun's dialect name to match the project's canonical vocabulary.
|
||||||
// Bun returns "pg" for PostgreSQL; the rest of the project uses "postgres".
|
// Bun returns "pg" for PostgreSQL; the rest of the project uses "postgres".
|
||||||
|
// Bun returns "sqlite3" for SQLite; we normalize to "sqlite".
|
||||||
switch name := b.db.Dialect().Name().String(); name {
|
switch name := b.db.Dialect().Name().String(); name {
|
||||||
case "pg":
|
case "pg":
|
||||||
return "postgres"
|
return "postgres"
|
||||||
|
case "sqlite3":
|
||||||
|
return "sqlite"
|
||||||
default:
|
default:
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
@@ -218,23 +294,25 @@ type BunSelectQuery struct {
|
|||||||
hasModel bool // Track if Model() was called
|
hasModel bool // Track if Model() was called
|
||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
|
entity string
|
||||||
tableAlias string
|
tableAlias string
|
||||||
|
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||||
inJoinContext bool // Track if we're in a JOIN relation context
|
inJoinContext bool // Track if we're in a JOIN relation context
|
||||||
joinTableAlias string // Alias to use for JOIN conditions
|
joinTableAlias string // Alias to use for JOIN conditions
|
||||||
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
skipAutoDetect bool // Skip auto-detection to prevent circular calls
|
||||||
|
preloadRelationAlias string // Relation alias used in separate-query preloads (e.g. "tprp" for relation "TPRP")
|
||||||
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
|
customPreloads map[string][]func(common.SelectQuery) common.SelectQuery // Relations to load with custom implementation
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
b.hasModel = true // Mark that we have a model
|
b.hasModel = true // Mark that we have a model
|
||||||
|
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||||
// Try to get table name from model if it implements TableNameProvider
|
if b.tableName == "" {
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||||
fullTableName := provider.TableName()
|
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
|
||||||
b.schema, b.tableName = parseTableName(fullTableName)
|
|
||||||
}
|
}
|
||||||
|
b.entity = entityNameFromModel(model, b.tableName)
|
||||||
|
|
||||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||||
b.tableAlias = provider.TableAlias()
|
b.tableAlias = provider.TableAlias()
|
||||||
@@ -246,7 +324,11 @@ func (b *BunSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
func (b *BunSelectQuery) Table(table string) common.SelectQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
b.schema, b.tableName = parseTableName(table)
|
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||||
|
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||||
|
if b.entity == "" {
|
||||||
|
b.entity = cleanMetricIdentifier(b.tableName)
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -265,12 +347,14 @@ func (b *BunSelectQuery) ColumnExpr(query string, args ...interface{}) common.Se
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) Where(query string, args ...interface{}) common.SelectQuery {
|
||||||
// If we're in a JOIN context, add table prefix to unqualified columns
|
|
||||||
if b.inJoinContext && b.joinTableAlias != "" {
|
if b.inJoinContext && b.joinTableAlias != "" {
|
||||||
query = addTablePrefix(query, b.joinTableAlias)
|
query = addTablePrefix(query, b.joinTableAlias)
|
||||||
|
} else if b.preloadRelationAlias != "" && b.tableName != "" {
|
||||||
|
// Separate-query preload: the caller may have written conditions using the
|
||||||
|
// relation name as a prefix (e.g. "TPRP.col"). Bun uses the real table name
|
||||||
|
// as the alias, so rewrite any such references to use tableName instead.
|
||||||
|
query = replaceRelationAlias(query, b.preloadRelationAlias, b.tableName)
|
||||||
} else if b.tableAlias != "" && b.tableName != "" {
|
} else if b.tableAlias != "" && b.tableName != "" {
|
||||||
// If we have a table alias defined, check if the query references a different alias
|
|
||||||
// This can happen in preloads where the user expects a certain alias but Bun generates another
|
|
||||||
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
|
query = normalizeTableAlias(query, b.tableAlias, b.tableName)
|
||||||
}
|
}
|
||||||
b.query = b.query.Where(query, args...)
|
b.query = b.query.Where(query, args...)
|
||||||
@@ -406,6 +490,38 @@ func normalizeTableAlias(query, expectedAlias, tableName string) string {
|
|||||||
return modified
|
return modified
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// replaceRelationAlias rewrites WHERE conditions written with a relation alias prefix
|
||||||
|
// (e.g. "TPRP.col") to use the real table name that bun uses in separate queries
|
||||||
|
// (e.g. "t_proposalinstance.col"). Only called for separate-query preload wrappers.
|
||||||
|
func replaceRelationAlias(query, relationAlias, tableName string) string {
|
||||||
|
if relationAlias == "" || tableName == "" || query == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
parts := strings.FieldsFunc(query, func(r rune) bool {
|
||||||
|
return r == ' ' || r == '(' || r == ')' || r == ','
|
||||||
|
})
|
||||||
|
modified := query
|
||||||
|
for _, part := range parts {
|
||||||
|
if dotIndex := strings.Index(part, "."); dotIndex > 0 {
|
||||||
|
prefix := part[:dotIndex]
|
||||||
|
column := part[dotIndex+1:]
|
||||||
|
if strings.EqualFold(prefix, relationAlias) {
|
||||||
|
logger.Debug("Replacing relation alias '%s' with table name '%s' in preload WHERE condition", prefix, tableName)
|
||||||
|
modified = strings.ReplaceAll(modified, part, tableName+"."+column)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return modified
|
||||||
|
}
|
||||||
|
|
||||||
|
func isJoinKeyword(word string) bool {
|
||||||
|
switch strings.ToUpper(word) {
|
||||||
|
case "JOIN", "INNER", "LEFT", "RIGHT", "FULL", "OUTER", "CROSS":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
func (b *BunSelectQuery) WhereOr(query string, args ...interface{}) common.SelectQuery {
|
||||||
b.query = b.query.WhereOr(query, args...)
|
b.query = b.query.WhereOr(query, args...)
|
||||||
return b
|
return b
|
||||||
@@ -436,7 +552,7 @@ func (b *BunSelectQuery) Join(query string, args ...interface{}) common.SelectQu
|
|||||||
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
// If query doesn't already have AS, check if it's a simple table name
|
// If query doesn't already have AS, check if it's a simple table name
|
||||||
parts := strings.Fields(query)
|
parts := strings.Fields(query)
|
||||||
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
if len(parts) > 0 && !isJoinKeyword(parts[0]) {
|
||||||
// Simple table name, add prefix: "table AS prefix"
|
// Simple table name, add prefix: "table AS prefix"
|
||||||
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
@@ -471,7 +587,7 @@ func (b *BunSelectQuery) LeftJoin(query string, args ...interface{}) common.Sele
|
|||||||
joinClause := query
|
joinClause := query
|
||||||
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
if prefix != "" && !strings.Contains(strings.ToUpper(query), " AS ") {
|
||||||
parts := strings.Fields(query)
|
parts := strings.Fields(query)
|
||||||
if len(parts) > 0 && !strings.HasPrefix(strings.ToUpper(parts[0]), "LEFT") && !strings.HasPrefix(strings.ToUpper(parts[0]), "JOIN") {
|
if len(parts) > 0 && !isJoinKeyword(parts[0]) {
|
||||||
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
joinClause = fmt.Sprintf("%s AS %s", parts[0], prefix)
|
||||||
if len(parts) > 1 {
|
if len(parts) > 1 {
|
||||||
joinClause += " " + strings.Join(parts[1:], " ")
|
joinClause += " " + strings.Join(parts[1:], " ")
|
||||||
@@ -516,6 +632,19 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
if !b.skipAutoDetect {
|
if !b.skipAutoDetect {
|
||||||
model := b.query.GetModel()
|
model := b.query.GetModel()
|
||||||
if model != nil && model.Value() != nil {
|
if model != nil && model.Value() != nil {
|
||||||
|
// Guard against relations that don't exist on the model. Without this,
|
||||||
|
// bun panics inside Count/Scan with `model=X does not have relation="Y"`.
|
||||||
|
// Only validate the root segment so nested paths (e.g. "PRM.CHILD") still
|
||||||
|
// fall through to bun's native resolution.
|
||||||
|
rootRelation := relation
|
||||||
|
if idx := strings.Index(rootRelation, "."); idx >= 0 {
|
||||||
|
rootRelation = rootRelation[:idx]
|
||||||
|
}
|
||||||
|
if reflection.GetRelationType(model.Value(), rootRelation) == reflection.RelationUnknown {
|
||||||
|
logger.Warn("Skipping preload '%s': relation '%s' is not declared on model %T", relation, rootRelation, model.Value())
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
relType := reflection.GetRelationType(model.Value(), relation)
|
relType := reflection.GetRelationType(model.Value(), relation)
|
||||||
|
|
||||||
// Log the detected relationship type
|
// Log the detected relationship type
|
||||||
@@ -554,6 +683,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
wrapper := &BunSelectQuery{
|
wrapper := &BunSelectQuery{
|
||||||
query: sq,
|
query: sq,
|
||||||
db: b.db,
|
db: b.db,
|
||||||
|
driverName: b.driverName,
|
||||||
|
metricsEnabled: b.metricsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to extract table name and alias from the preload model
|
// Try to extract table name and alias from the preload model
|
||||||
@@ -563,7 +694,8 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
// Extract table name if model implements TableNameProvider
|
// Extract table name if model implements TableNameProvider
|
||||||
if provider, ok := modelValue.(common.TableNameProvider); ok {
|
if provider, ok := modelValue.(common.TableNameProvider); ok {
|
||||||
fullTableName := provider.TableName()
|
fullTableName := provider.TableName()
|
||||||
wrapper.schema, wrapper.tableName = parseTableName(fullTableName)
|
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||||
|
wrapper.schema, wrapper.tableName = parseTableName(fullTableName, b.driverName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract table alias if model implements TableAliasProvider
|
// Extract table alias if model implements TableAliasProvider
|
||||||
@@ -571,8 +703,20 @@ func (b *BunSelectQuery) PreloadRelation(relation string, apply ...func(common.S
|
|||||||
wrapper.tableAlias = provider.TableAlias()
|
wrapper.tableAlias = provider.TableAlias()
|
||||||
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
logger.Debug("Preload relation '%s' using table alias: %s", relation, wrapper.tableAlias)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fallback: if the model didn't provide a table name, ask bun directly.
|
||||||
|
if wrapper.tableName == "" {
|
||||||
|
wrapper.schema, wrapper.tableName = parseTableName(sq.GetTableName(), b.driverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For separate-query preloads (has-many), bun aliases the related table using
|
||||||
|
// the actual table name, not the relation name. Record the relation alias so
|
||||||
|
// Where() can rewrite conditions like "TPRP.col" to "t_proposalinstance.col".
|
||||||
|
wrapper.preloadRelationAlias = strings.ToLower(relation)
|
||||||
|
logger.Debug("Preload relation '%s' registered alias '%s' for separate-query WHERE rewriting", relation, wrapper.preloadRelationAlias)
|
||||||
|
|
||||||
// Start with the interface value (not pointer)
|
// Start with the interface value (not pointer)
|
||||||
current := common.SelectQuery(wrapper)
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
@@ -803,7 +947,7 @@ func (b *BunSelectQuery) loadRelationLevel(ctx context.Context, parentRecords re
|
|||||||
|
|
||||||
// Apply user's functions (if any)
|
// Apply user's functions (if any)
|
||||||
if isLast && len(applyFuncs) > 0 {
|
if isLast && len(applyFuncs) > 0 {
|
||||||
wrapper := &BunSelectQuery{query: query, db: b.db}
|
wrapper := &BunSelectQuery{query: query, db: b.db, driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
for _, fn := range applyFuncs {
|
for _, fn := range applyFuncs {
|
||||||
if fn != nil {
|
if fn != nil {
|
||||||
wrapper = fn(wrapper).(*BunSelectQuery)
|
wrapper = fn(wrapper).(*BunSelectQuery)
|
||||||
@@ -1155,27 +1299,29 @@ func (b *BunSelectQuery) Having(having string, args ...interface{}) common.Selec
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
func (b *BunSelectQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
err = logger.HandlePanic("BunSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
if dest == nil {
|
if dest == nil {
|
||||||
return fmt.Errorf("destination cannot be nil")
|
err = fmt.Errorf("destination cannot be nil")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = b.query.Scan(ctx, dest)
|
err = b.query.Scan(ctx, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
// Enhanced panic recovery with model information
|
// Enhanced panic recovery with model information
|
||||||
@@ -1185,7 +1331,6 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
modelValue := model.Value()
|
modelValue := model.Value()
|
||||||
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
modelInfo = fmt.Sprintf("Model type: %T", modelValue)
|
||||||
|
|
||||||
// Try to get the model's underlying struct type
|
|
||||||
v := reflect.ValueOf(modelValue)
|
v := reflect.ValueOf(modelValue)
|
||||||
if v.Kind() == reflect.Ptr {
|
if v.Kind() == reflect.Ptr {
|
||||||
v = v.Elem()
|
v = v.Elem()
|
||||||
@@ -1205,9 +1350,11 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
logger.Error("Panic in BunSelectQuery.ScanModel: %v. %s. SQL: %s", r, modelInfo, sqlStr)
|
||||||
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
err = logger.HandlePanic("BunSelectQuery.ScanModel", r)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "SELECT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
if b.query.GetModel() == nil {
|
if b.query.GetModel() == nil {
|
||||||
return fmt.Errorf("model is nil")
|
err = fmt.Errorf("model is nil")
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optional: Enable detailed field-level debugging (set to true to debug)
|
// Optional: Enable detailed field-level debugging (set to true to debug)
|
||||||
@@ -1223,16 +1370,15 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
err = b.query.Scan(ctx)
|
err = b.query.Scan(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
return err
|
return common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// After main query, load custom preloads using separate queries
|
// After main query, load custom preloads using separate queries
|
||||||
if len(b.customPreloads) > 0 {
|
if len(b.customPreloads) > 0 {
|
||||||
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
|
logger.Info("Loading %d custom preload(s) with separate queries", len(b.customPreloads))
|
||||||
if err := b.loadCustomPreloads(ctx); err != nil {
|
if err = b.loadCustomPreloads(ctx); err != nil {
|
||||||
logger.Error("Failed to load custom preloads: %v", err)
|
logger.Error("Failed to load custom preloads: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -1242,21 +1388,23 @@ func (b *BunSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
err = logger.HandlePanic("BunSelectQuery.Count", r)
|
||||||
count = 0
|
count = 0
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "COUNT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
// If Model() was set, use bun's native Count() which works properly
|
// If Model() was set, use bun's native Count() which works properly
|
||||||
if b.hasModel {
|
if b.hasModel {
|
||||||
count, err := b.query.Count(ctx)
|
count, err = b.query.Count(ctx) // assign to named returns, not shadow vars
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
return count, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
// Otherwise, wrap as subquery to avoid "Model(nil)" error
|
||||||
@@ -1266,27 +1414,29 @@ func (b *BunSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
ColumnExpr("COUNT(*)")
|
ColumnExpr("COUNT(*)")
|
||||||
err = countQuery.Scan(ctx, &count)
|
err = countQuery.Scan(ctx, &count)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
|
||||||
sqlStr := countQuery.String()
|
sqlStr := countQuery.String()
|
||||||
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunSelectQuery.Count (subquery) failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
return count, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
func (b *BunSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
err = logger.HandlePanic("BunSelectQuery.Exists", r)
|
||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "EXISTS", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
exists, err = b.query.Exists(ctx)
|
exists, err = b.query.Exists(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
return exists, err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// BunInsertQuery implements InsertQuery for Bun
|
// BunInsertQuery implements InsertQuery for Bun
|
||||||
@@ -1294,11 +1444,21 @@ type BunInsertQuery struct {
|
|||||||
query *bun.InsertQuery
|
query *bun.InsertQuery
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
hasModel bool
|
hasModel bool
|
||||||
|
driverName string
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
func (b *BunInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
b.hasModel = true
|
b.hasModel = true
|
||||||
|
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||||
|
if b.tableName == "" {
|
||||||
|
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||||
|
}
|
||||||
|
b.entity = entityNameFromModel(model, b.tableName)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1307,6 +1467,10 @@ func (b *BunInsertQuery) Table(table string) common.InsertQuery {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||||
|
if b.entity == "" {
|
||||||
|
b.entity = cleanMetricIdentifier(b.tableName)
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1325,53 +1489,84 @@ func (b *BunInsertQuery) OnConflict(action string) common.InsertQuery {
|
|||||||
|
|
||||||
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
func (b *BunInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||||
if len(columns) > 0 {
|
if len(columns) > 0 {
|
||||||
b.query = b.query.Returning(columns[0])
|
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunInsertQuery) prepareValues() {
|
||||||
|
if len(b.values) > 0 {
|
||||||
|
if !b.hasModel {
|
||||||
|
b.query = b.query.Model(&b.values)
|
||||||
|
} else {
|
||||||
|
for k, v := range b.values {
|
||||||
|
b.query = b.query.Value(k, "?", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (b *BunInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
err = logger.HandlePanic("BunInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
if len(b.values) > 0 {
|
startedAt := time.Now()
|
||||||
if !b.hasModel {
|
b.prepareValues()
|
||||||
// If no model was set, use the values map as the model
|
|
||||||
// Bun can insert map[string]interface{} directly
|
|
||||||
b.query = b.query.Model(&b.values)
|
|
||||||
} else {
|
|
||||||
// If model was set, use Value() to add individual values
|
|
||||||
for k, v := range b.values {
|
|
||||||
b.query = b.query.Value(k, "?", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *BunInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("BunInsertQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
b.prepareValues()
|
||||||
|
err = b.query.Scan(ctx, dest)
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "INSERT", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// BunUpdateQuery implements UpdateQuery for Bun
|
// BunUpdateQuery implements UpdateQuery for Bun
|
||||||
type BunUpdateQuery struct {
|
type BunUpdateQuery struct {
|
||||||
query *bun.UpdateQuery
|
query *bun.UpdateQuery
|
||||||
model interface{}
|
model interface{}
|
||||||
|
driverName string
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
func (b *BunUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
b.model = model
|
b.model = model
|
||||||
|
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||||
|
if b.tableName == "" {
|
||||||
|
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||||
|
}
|
||||||
|
b.entity = entityNameFromModel(model, b.tableName)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
func (b *BunUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||||
|
if b.entity == "" {
|
||||||
|
b.entity = cleanMetricIdentifier(b.tableName)
|
||||||
|
}
|
||||||
if b.model == nil {
|
if b.model == nil {
|
||||||
// Try to get table name from table string if model is not set
|
// Try to get table name from table string if model is not set
|
||||||
|
|
||||||
model, err := modelregistry.GetModelByName(table)
|
model, err := modelregistry.GetModelByName(table)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
b.model = model
|
b.model = model
|
||||||
|
b.entity = entityNameFromModel(model, b.tableName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
@@ -1399,7 +1594,7 @@ func (b *BunUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQuer
|
|||||||
// Skip primary key updates
|
// Skip primary key updates
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
b.query = b.query.Set(column+" = ?", value)
|
b.query = b.query.Set(column+" = ?", common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
@@ -1411,7 +1606,7 @@ func (b *BunUpdateQuery) Where(query string, args ...interface{}) common.UpdateQ
|
|||||||
|
|
||||||
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
func (b *BunUpdateQuery) Returning(columns ...string) common.UpdateQuery {
|
||||||
if len(columns) > 0 {
|
if len(columns) > 0 {
|
||||||
b.query = b.query.Returning(columns[0])
|
b.query = b.query.Returning(strings.Join(columns, ", "))
|
||||||
}
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
@@ -1422,27 +1617,44 @@ func (b *BunUpdateQuery) Exec(ctx context.Context) (res common.Result, err error
|
|||||||
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
err = logger.HandlePanic("BunUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "UPDATE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// BunDeleteQuery implements DeleteQuery for Bun
|
// BunDeleteQuery implements DeleteQuery for Bun
|
||||||
type BunDeleteQuery struct {
|
type BunDeleteQuery struct {
|
||||||
query *bun.DeleteQuery
|
query *bun.DeleteQuery
|
||||||
|
driverName string
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
func (b *BunDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||||
b.query = b.query.Model(model)
|
b.query = b.query.Model(model)
|
||||||
|
b.schema, b.tableName = schemaAndTableFromModel(model, b.driverName)
|
||||||
|
if b.tableName == "" {
|
||||||
|
b.schema, b.tableName = parseTableName(b.query.GetTableName(), b.driverName)
|
||||||
|
}
|
||||||
|
b.entity = entityNameFromModel(model, b.tableName)
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
|
func (b *BunDeleteQuery) Table(table string) common.DeleteQuery {
|
||||||
b.query = b.query.Table(table)
|
b.query = b.query.Table(table)
|
||||||
|
b.schema, b.tableName = parseTableName(table, b.driverName)
|
||||||
|
if b.entity == "" {
|
||||||
|
b.entity = cleanMetricIdentifier(b.tableName)
|
||||||
|
}
|
||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1457,12 +1669,15 @@ func (b *BunDeleteQuery) Exec(ctx context.Context) (res common.Result, err error
|
|||||||
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
err = logger.HandlePanic("BunDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
result, err := b.query.Exec(ctx)
|
result, err := b.query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := b.query.String()
|
sqlStr := b.query.String()
|
||||||
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("BunDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(b.metricsEnabled, "DELETE", b.schema, b.entity, b.tableName, startedAt, err)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1490,34 +1705,44 @@ func (b *BunResult) LastInsertId() (int64, error) {
|
|||||||
type BunTxAdapter struct {
|
type BunTxAdapter struct {
|
||||||
tx bun.Tx
|
tx bun.Tx
|
||||||
driverName string
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
func (b *BunTxAdapter) NewSelect() common.SelectQuery {
|
||||||
return &BunSelectQuery{
|
return &BunSelectQuery{
|
||||||
query: b.tx.NewSelect(),
|
query: b.tx.NewSelect(),
|
||||||
db: b.tx,
|
db: b.tx,
|
||||||
|
driverName: b.driverName,
|
||||||
|
metricsEnabled: b.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
func (b *BunTxAdapter) NewInsert() common.InsertQuery {
|
||||||
return &BunInsertQuery{query: b.tx.NewInsert()}
|
return &BunInsertQuery{query: b.tx.NewInsert(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
|
func (b *BunTxAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &BunUpdateQuery{query: b.tx.NewUpdate()}
|
return &BunUpdateQuery{query: b.tx.NewUpdate(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
|
func (b *BunTxAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &BunDeleteQuery{query: b.tx.NewDelete()}
|
return &BunDeleteQuery{query: b.tx.NewDelete(), driverName: b.driverName, metricsEnabled: b.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (b *BunTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||||
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||||
result, err := b.tx.ExecContext(ctx, query, args...)
|
result, err := b.tx.ExecContext(ctx, query, args...)
|
||||||
|
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
return &BunResult{result: result}, err
|
return &BunResult{result: result}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (b *BunTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
return b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, b.driverName)
|
||||||
|
err := b.tx.NewRaw(query, args...).Scan(ctx, dest)
|
||||||
|
recordQueryMetrics(b.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (b *BunTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
|
|||||||
@@ -3,9 +3,13 @@ package database
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/clause"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -15,18 +19,93 @@ import (
|
|||||||
|
|
||||||
// GormAdapter adapts GORM to work with our Database interface
|
// GormAdapter adapts GORM to work with our Database interface
|
||||||
type GormAdapter struct {
|
type GormAdapter struct {
|
||||||
|
dbMu sync.RWMutex
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
dbFactory func() (*gorm.DB, error)
|
||||||
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewGormAdapter creates a new GORM adapter
|
// NewGormAdapter creates a new GORM adapter
|
||||||
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
func NewGormAdapter(db *gorm.DB) *GormAdapter {
|
||||||
return &GormAdapter{db: db}
|
adapter := &GormAdapter{db: db, metricsEnabled: true}
|
||||||
|
// Initialize driver name
|
||||||
|
adapter.driverName = adapter.DriverName()
|
||||||
|
return adapter
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (g *GormAdapter) WithDBFactory(factory func() (*gorm.DB, error)) *GormAdapter {
|
||||||
|
g.dbFactory = factory
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||||
|
func (g *GormAdapter) SetMetricsEnabled(enabled bool) *GormAdapter {
|
||||||
|
g.metricsEnabled = enabled
|
||||||
|
return g
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) getDB() *gorm.DB {
|
||||||
|
g.dbMu.RLock()
|
||||||
|
defer g.dbMu.RUnlock()
|
||||||
|
return g.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GormAdapter) reconnectDB(targets ...*gorm.DB) error {
|
||||||
|
if g.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
freshDB, err := g.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
g.dbMu.Lock()
|
||||||
|
previous := g.db
|
||||||
|
g.db = freshDB
|
||||||
|
g.driverName = normalizeGormDriverName(freshDB)
|
||||||
|
g.dbMu.Unlock()
|
||||||
|
|
||||||
|
if previous != nil {
|
||||||
|
syncGormConnPool(previous, freshDB)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, target := range targets {
|
||||||
|
if target != nil && target != previous {
|
||||||
|
syncGormConnPool(target, freshDB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncGormConnPool(target, fresh *gorm.DB) {
|
||||||
|
if target == nil || fresh == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Config != nil && fresh.Config != nil {
|
||||||
|
target.ConnPool = fresh.ConnPool
|
||||||
|
}
|
||||||
|
|
||||||
|
if target.Statement != nil {
|
||||||
|
if fresh.Statement != nil && fresh.Statement.ConnPool != nil {
|
||||||
|
target.Statement.ConnPool = fresh.Statement.ConnPool
|
||||||
|
} else if fresh.Config != nil {
|
||||||
|
target.Statement.ConnPool = fresh.ConnPool
|
||||||
|
}
|
||||||
|
target.Statement.DB = target
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
// EnableQueryDebug enables query debugging which logs all SQL queries including preloads
|
||||||
// This is useful for debugging preload queries that may be failing
|
// This is useful for debugging preload queries that may be failing
|
||||||
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
func (g *GormAdapter) EnableQueryDebug() *GormAdapter {
|
||||||
|
g.dbMu.Lock()
|
||||||
g.db = g.db.Debug()
|
g.db = g.db.Debug()
|
||||||
|
g.dbMu.Unlock()
|
||||||
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
logger.Info("GORM query debug mode enabled - all SQL queries will be logged")
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
@@ -40,19 +119,19 @@ func (g *GormAdapter) DisableQueryDebug() *GormAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
func (g *GormAdapter) NewSelect() common.SelectQuery {
|
||||||
return &GormSelectQuery{db: g.db}
|
return &GormSelectQuery{db: g.getDB(), driverName: g.driverName, reconnect: g.reconnectDB, metricsEnabled: g.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
func (g *GormAdapter) NewInsert() common.InsertQuery {
|
||||||
return &GormInsertQuery{db: g.db}
|
return &GormInsertQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
func (g *GormAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &GormUpdateQuery{db: g.db}
|
return &GormUpdateQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
func (g *GormAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &GormDeleteQuery{db: g.db}
|
return &GormDeleteQuery{db: g.getDB(), reconnect: g.reconnectDB, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{}) (res common.Result, err error) {
|
||||||
@@ -61,7 +140,18 @@ func (g *GormAdapter) Exec(ctx context.Context, query string, args ...interface{
|
|||||||
err = logger.HandlePanic("GormAdapter.Exec", r)
|
err = logger.HandlePanic("GormAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Exec(query, args...)
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||||
|
run := func() *gorm.DB {
|
||||||
|
return g.getDB().WithContext(ctx).Exec(query, args...)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, result.Error)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,15 +161,35 @@ func (g *GormAdapter) Query(ctx context.Context, dest interface{}, query string,
|
|||||||
err = logger.HandlePanic("GormAdapter.Query", r)
|
err = logger.HandlePanic("GormAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Raw(query, args...).Find(dest).Error
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, g.driverName)
|
||||||
|
run := func() error {
|
||||||
|
return g.getDB().WithContext(ctx).Raw(query, args...).Find(dest).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (g *GormAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx := g.db.WithContext(ctx).Begin()
|
run := func() *gorm.DB {
|
||||||
|
return g.getDB().WithContext(ctx).Begin()
|
||||||
|
}
|
||||||
|
tx := run()
|
||||||
|
if isDBClosed(tx.Error) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
tx = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if tx.Error != nil {
|
if tx.Error != nil {
|
||||||
return nil, tx.Error
|
return nil, tx.Error
|
||||||
}
|
}
|
||||||
return &GormAdapter{db: tx}, nil
|
return &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
func (g *GormAdapter) CommitTx(ctx context.Context) error {
|
||||||
@@ -96,25 +206,41 @@ func (g *GormAdapter) RunInTransaction(ctx context.Context, fn func(common.Datab
|
|||||||
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
err = logger.HandlePanic("GormAdapter.RunInTransaction", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return g.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
run := func() error {
|
||||||
adapter := &GormAdapter{db: tx}
|
return g.getDB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||||
|
adapter := &GormAdapter{db: tx, dbFactory: g.dbFactory, driverName: g.driverName, metricsEnabled: g.metricsEnabled}
|
||||||
return fn(adapter)
|
return fn(adapter)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := g.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
func (g *GormAdapter) GetUnderlyingDB() interface{} {
|
||||||
return g.db
|
return g.getDB()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormAdapter) DriverName() string {
|
func (g *GormAdapter) DriverName() string {
|
||||||
if g.db.Dialector == nil {
|
return normalizeGormDriverName(g.getDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeGormDriverName(db *gorm.DB) string {
|
||||||
|
if db == nil || db.Dialector == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
// Normalize GORM's dialector name to match the project's canonical vocabulary.
|
||||||
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
// GORM returns "sqlserver" for MSSQL; the rest of the project uses "mssql".
|
||||||
switch name := g.db.Name(); name {
|
// GORM returns "sqlite" or "sqlite3" for SQLite; we normalize to "sqlite".
|
||||||
|
switch name := db.Name(); name {
|
||||||
case "sqlserver":
|
case "sqlserver":
|
||||||
return "mssql"
|
return "mssql"
|
||||||
|
case "sqlite3":
|
||||||
|
return "sqlite"
|
||||||
default:
|
default:
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
@@ -123,22 +249,21 @@ func (g *GormAdapter) DriverName() string {
|
|||||||
// GormSelectQuery implements SelectQuery for GORM
|
// GormSelectQuery implements SelectQuery for GORM
|
||||||
type GormSelectQuery struct {
|
type GormSelectQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
reconnect func(...*gorm.DB) error
|
||||||
schema string // Separated schema name
|
schema string // Separated schema name
|
||||||
tableName string // Just the table name, without schema
|
tableName string // Just the table name, without schema
|
||||||
|
entity string
|
||||||
tableAlias string
|
tableAlias string
|
||||||
|
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||||
inJoinContext bool // Track if we're in a JOIN relation context
|
inJoinContext bool // Track if we're in a JOIN relation context
|
||||||
joinTableAlias string // Alias to use for JOIN conditions
|
joinTableAlias string // Alias to use for JOIN conditions
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
g.db = g.db.Model(model)
|
g.db = g.db.Model(model)
|
||||||
|
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||||
// Try to get table name from model if it implements TableNameProvider
|
g.entity = entityNameFromModel(model, g.tableName)
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
|
||||||
fullTableName := provider.TableName()
|
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
|
||||||
g.schema, g.tableName = parseTableName(fullTableName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||||
g.tableAlias = provider.TableAlias()
|
g.tableAlias = provider.TableAlias()
|
||||||
@@ -150,7 +275,11 @@ func (g *GormSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
func (g *GormSelectQuery) Table(table string) common.SelectQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
// Check if the table name contains schema (e.g., "schema.table")
|
// Check if the table name contains schema (e.g., "schema.table")
|
||||||
g.schema, g.tableName = parseTableName(table)
|
// For SQLite, this will convert "schema.table" to "schema_table"
|
||||||
|
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||||
|
if g.entity == "" {
|
||||||
|
g.entity = cleanMetricIdentifier(g.tableName)
|
||||||
|
}
|
||||||
|
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
@@ -337,6 +466,9 @@ func (g *GormSelectQuery) PreloadRelation(relation string, apply ...func(common.
|
|||||||
|
|
||||||
wrapper := &GormSelectQuery{
|
wrapper := &GormSelectQuery{
|
||||||
db: db,
|
db: db,
|
||||||
|
reconnect: g.reconnect,
|
||||||
|
driverName: g.driverName,
|
||||||
|
metricsEnabled: g.metricsEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
current := common.SelectQuery(wrapper)
|
current := common.SelectQuery(wrapper)
|
||||||
@@ -374,8 +506,11 @@ func (g *GormSelectQuery) JoinRelation(relation string, apply ...func(common.Sel
|
|||||||
|
|
||||||
wrapper := &GormSelectQuery{
|
wrapper := &GormSelectQuery{
|
||||||
db: db,
|
db: db,
|
||||||
|
reconnect: g.reconnect,
|
||||||
|
driverName: g.driverName,
|
||||||
inJoinContext: true, // Mark as JOIN context
|
inJoinContext: true, // Mark as JOIN context
|
||||||
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
joinTableAlias: strings.ToLower(relation), // Use relation name as alias
|
||||||
|
metricsEnabled: g.metricsEnabled,
|
||||||
}
|
}
|
||||||
current := common.SelectQuery(wrapper)
|
current := common.SelectQuery(wrapper)
|
||||||
|
|
||||||
@@ -432,14 +567,25 @@ func (g *GormSelectQuery) Scan(ctx context.Context, dest interface{}) (err error
|
|||||||
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
err = logger.HandlePanic("GormSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
err = g.db.WithContext(ctx).Find(dest).Error
|
startedAt := time.Now()
|
||||||
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Find(dest).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Find(dest)
|
return tx.Find(dest)
|
||||||
})
|
})
|
||||||
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("GormSelectQuery.Scan failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -452,14 +598,25 @@ func (g *GormSelectQuery) ScanModel(ctx context.Context) (err error) {
|
|||||||
if g.db.Statement.Model == nil {
|
if g.db.Statement.Model == nil {
|
||||||
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
return fmt.Errorf("ScanModel requires Model() to be set before scanning")
|
||||||
}
|
}
|
||||||
err = g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
startedAt := time.Now()
|
||||||
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Find(g.db.Statement.Model).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Find(g.db.Statement.Model)
|
return tx.Find(g.db.Statement.Model)
|
||||||
})
|
})
|
||||||
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("GormSelectQuery.ScanModel failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "SELECT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -470,15 +627,26 @@ func (g *GormSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
count = 0
|
count = 0
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
var count64 int64
|
var count64 int64
|
||||||
err = g.db.WithContext(ctx).Count(&count64).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Count(&count64).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Count(&count64)
|
return tx.Count(&count64)
|
||||||
})
|
})
|
||||||
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("GormSelectQuery.Count failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "COUNT", g.schema, g.entity, g.tableName, startedAt, err)
|
||||||
return int(count64), err
|
return int(count64), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -489,33 +657,57 @@ func (g *GormSelectQuery) Exists(ctx context.Context) (exists bool, err error) {
|
|||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
var count int64
|
var count int64
|
||||||
err = g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
run := func() error {
|
||||||
|
return g.db.WithContext(ctx).Limit(1).Count(&count).Error
|
||||||
|
}
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Limit(1).Count(&count)
|
return tx.Limit(1).Count(&count)
|
||||||
})
|
})
|
||||||
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
logger.Error("GormSelectQuery.Exists failed. SQL: %s. Error: %v", sqlStr, err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "EXISTS", g.schema, g.entity, g.tableName, startedAt, err)
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GormInsertQuery implements InsertQuery for GORM
|
// GormInsertQuery implements InsertQuery for GORM
|
||||||
type GormInsertQuery struct {
|
type GormInsertQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
reconnect func(...*gorm.DB) error
|
||||||
model interface{}
|
model interface{}
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
|
returningColumns []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
func (g *GormInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||||
g.model = model
|
g.model = model
|
||||||
g.db = g.db.Model(model)
|
g.db = g.db.Model(model)
|
||||||
|
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||||
|
g.entity = entityNameFromModel(model, g.tableName)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
|
func (g *GormInsertQuery) Table(table string) common.InsertQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||||
|
if g.entity == "" {
|
||||||
|
g.entity = cleanMetricIdentifier(g.tableName)
|
||||||
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -533,7 +725,7 @@ func (g *GormInsertQuery) OnConflict(action string) common.InsertQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
func (g *GormInsertQuery) Returning(columns ...string) common.InsertQuery {
|
||||||
// GORM doesn't have explicit RETURNING, but updates the model
|
g.returningColumns = columns
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -543,38 +735,130 @@ func (g *GormInsertQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
err = logger.HandlePanic("GormInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
run := func() *gorm.DB {
|
||||||
|
switch {
|
||||||
|
case g.model != nil:
|
||||||
|
return g.db.WithContext(ctx).Create(g.model)
|
||||||
|
case g.values != nil:
|
||||||
|
return g.db.WithContext(ctx).Create(g.values)
|
||||||
|
default:
|
||||||
|
return g.db.WithContext(ctx).Create(map[string]interface{}{})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||||
|
return &GormResult{result: result}, result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *GormInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("GormInsertQuery.Scan", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
|
||||||
|
var returningCols []clause.Column
|
||||||
|
for _, col := range g.returningColumns {
|
||||||
|
returningCols = append(returningCols, clause.Column{Name: col})
|
||||||
|
}
|
||||||
|
|
||||||
|
db := g.db.WithContext(ctx)
|
||||||
|
if len(returningCols) > 0 {
|
||||||
|
db = db.Clauses(clause.Returning{Columns: returningCols})
|
||||||
|
}
|
||||||
|
|
||||||
var result *gorm.DB
|
var result *gorm.DB
|
||||||
switch {
|
switch {
|
||||||
case g.model != nil:
|
case g.model != nil:
|
||||||
result = g.db.WithContext(ctx).Create(g.model)
|
result = db.Create(g.model)
|
||||||
case g.values != nil:
|
case g.values != nil:
|
||||||
result = g.db.WithContext(ctx).Create(g.values)
|
result = db.Create(g.values)
|
||||||
default:
|
default:
|
||||||
result = g.db.WithContext(ctx).Create(map[string]interface{}{})
|
result = db.Create(map[string]interface{}{})
|
||||||
}
|
}
|
||||||
return &GormResult{result: result}, result.Error
|
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = db.Create(g.model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "INSERT", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||||
|
if result.Error != nil {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the returning column value from the model or values map
|
||||||
|
if len(g.returningColumns) == 1 {
|
||||||
|
col := g.returningColumns[0]
|
||||||
|
if g.model != nil {
|
||||||
|
val := reflect.ValueOf(g.model)
|
||||||
|
if val.Kind() == reflect.Ptr {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
if val.Kind() == reflect.Struct {
|
||||||
|
for i := 0; i < val.NumField(); i++ {
|
||||||
|
f := val.Type().Field(i)
|
||||||
|
dbTag := strings.Split(f.Tag.Get("bun"), ",")[0]
|
||||||
|
jsonTag := strings.Split(f.Tag.Get("json"), ",")[0]
|
||||||
|
if strings.EqualFold(f.Name, col) || dbTag == col || jsonTag == col {
|
||||||
|
reflect.ValueOf(dest).Elem().Set(val.Field(i))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if g.values != nil {
|
||||||
|
if v, ok := g.values[col]; ok {
|
||||||
|
reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(v))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GormUpdateQuery implements UpdateQuery for GORM
|
// GormUpdateQuery implements UpdateQuery for GORM
|
||||||
type GormUpdateQuery struct {
|
type GormUpdateQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
reconnect func(...*gorm.DB) error
|
||||||
model interface{}
|
model interface{}
|
||||||
updates interface{}
|
updates interface{}
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
func (g *GormUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||||
g.model = model
|
g.model = model
|
||||||
g.db = g.db.Model(model)
|
g.db = g.db.Model(model)
|
||||||
|
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||||
|
g.entity = entityNameFromModel(model, g.tableName)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
func (g *GormUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||||
|
if g.entity == "" {
|
||||||
|
g.entity = cleanMetricIdentifier(g.tableName)
|
||||||
|
}
|
||||||
if g.model == nil {
|
if g.model == nil {
|
||||||
// Try to get table name from table string if model is not set
|
// Try to get table name from table string if model is not set
|
||||||
model, err := modelregistry.GetModelByName(table)
|
model, err := modelregistry.GetModelByName(table)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
g.model = model
|
g.model = model
|
||||||
|
g.entity = entityNameFromModel(model, g.tableName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return g
|
return g
|
||||||
@@ -635,31 +919,54 @@ func (g *GormUpdateQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
err = logger.HandlePanic("GormUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Updates(g.updates)
|
startedAt := time.Now()
|
||||||
|
run := func() *gorm.DB {
|
||||||
|
return g.db.WithContext(ctx).Updates(g.updates)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Updates(g.updates)
|
return tx.Updates(g.updates)
|
||||||
})
|
})
|
||||||
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
logger.Error("GormUpdateQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||||
|
return &GormResult{result: result}, common.WrapSQLError(result.Error, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "UPDATE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GormDeleteQuery implements DeleteQuery for GORM
|
// GormDeleteQuery implements DeleteQuery for GORM
|
||||||
type GormDeleteQuery struct {
|
type GormDeleteQuery struct {
|
||||||
db *gorm.DB
|
db *gorm.DB
|
||||||
|
reconnect func(...*gorm.DB) error
|
||||||
model interface{}
|
model interface{}
|
||||||
|
schema string
|
||||||
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
func (g *GormDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||||
g.model = model
|
g.model = model
|
||||||
g.db = g.db.Model(model)
|
g.db = g.db.Model(model)
|
||||||
|
g.schema, g.tableName = schemaAndTableFromModel(model, g.driverName)
|
||||||
|
g.entity = entityNameFromModel(model, g.tableName)
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
|
func (g *GormDeleteQuery) Table(table string) common.DeleteQuery {
|
||||||
g.db = g.db.Table(table)
|
g.db = g.db.Table(table)
|
||||||
|
g.schema, g.tableName = parseTableName(table, g.driverName)
|
||||||
|
if g.entity == "" {
|
||||||
|
g.entity = cleanMetricIdentifier(g.tableName)
|
||||||
|
}
|
||||||
return g
|
return g
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -674,14 +981,25 @@ func (g *GormDeleteQuery) Exec(ctx context.Context) (res common.Result, err erro
|
|||||||
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
err = logger.HandlePanic("GormDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
result := g.db.WithContext(ctx).Delete(g.model)
|
startedAt := time.Now()
|
||||||
|
run := func() *gorm.DB {
|
||||||
|
return g.db.WithContext(ctx).Delete(g.model)
|
||||||
|
}
|
||||||
|
result := run()
|
||||||
|
if isDBClosed(result.Error) && g.reconnect != nil {
|
||||||
|
if reconnErr := g.reconnect(g.db); reconnErr == nil {
|
||||||
|
result = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
// Log SQL string for debugging
|
// Log SQL string for debugging
|
||||||
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
sqlStr := g.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
|
||||||
return tx.Delete(g.model)
|
return tx.Delete(g.model)
|
||||||
})
|
})
|
||||||
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
logger.Error("GormDeleteQuery.Exec failed. SQL: %s. Error: %v", sqlStr, result.Error)
|
||||||
|
return &GormResult{result: result}, common.WrapSQLError(result.Error, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(g.metricsEnabled, "DELETE", g.schema, g.entity, g.tableName, startedAt, result.Error)
|
||||||
return &GormResult{result: result}, result.Error
|
return &GormResult{result: result}, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -17,7 +20,10 @@ import (
|
|||||||
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
// This provides a lightweight PostgreSQL adapter without ORM overhead
|
||||||
type PgSQLAdapter struct {
|
type PgSQLAdapter struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
driverName string
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
// NewPgSQLAdapter creates a new adapter wrapping a standard sql.DB.
|
||||||
@@ -28,7 +34,43 @@ func NewPgSQLAdapter(db *sql.DB, driverName ...string) *PgSQLAdapter {
|
|||||||
if len(driverName) > 0 && driverName[0] != "" {
|
if len(driverName) > 0 && driverName[0] != "" {
|
||||||
name = driverName[0]
|
name = driverName[0]
|
||||||
}
|
}
|
||||||
return &PgSQLAdapter{db: db, driverName: name}
|
return &PgSQLAdapter{db: db, driverName: name, metricsEnabled: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (p *PgSQLAdapter) WithDBFactory(factory func() (*sql.DB, error)) *PgSQLAdapter {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMetricsEnabled enables or disables query metrics for this adapter.
|
||||||
|
func (p *PgSQLAdapter) SetMetricsEnabled(enabled bool) *PgSQLAdapter {
|
||||||
|
p.metricsEnabled = enabled
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLAdapter) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLAdapter) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDBClosed(err error) bool {
|
||||||
|
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableQueryDebug enables query debugging for development
|
// EnableQueryDebug enables query debugging for development
|
||||||
@@ -38,33 +80,41 @@ func (p *PgSQLAdapter) EnableQueryDebug() {
|
|||||||
|
|
||||||
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
func (p *PgSQLAdapter) NewSelect() common.SelectQuery {
|
||||||
return &PgSQLSelectQuery{
|
return &PgSQLSelectQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
|
driverName: p.driverName,
|
||||||
columns: []string{"*"},
|
columns: []string{"*"},
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
func (p *PgSQLAdapter) NewInsert() common.InsertQuery {
|
||||||
return &PgSQLInsertQuery{
|
return &PgSQLInsertQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
|
driverName: p.driverName,
|
||||||
values: make(map[string]interface{}),
|
values: make(map[string]interface{}),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
func (p *PgSQLAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &PgSQLUpdateQuery{
|
return &PgSQLUpdateQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
|
driverName: p.driverName,
|
||||||
sets: make(map[string]interface{}),
|
sets: make(map[string]interface{}),
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
whereClauses: make([]string, 0),
|
whereClauses: make([]string, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
func (p *PgSQLAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &PgSQLDeleteQuery{
|
return &PgSQLDeleteQuery{
|
||||||
db: p.db,
|
db: p.getDB(),
|
||||||
|
driverName: p.driverName,
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
whereClauses: make([]string, 0),
|
whereClauses: make([]string, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,12 +124,23 @@ func (p *PgSQLAdapter) Exec(ctx context.Context, query string, args ...interface
|
|||||||
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
|
err = logger.HandlePanic("PgSQLAdapter.Exec", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||||
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Exec: %s [args: %v]", query, args)
|
||||||
result, err := p.db.ExecContext(ctx, query, args...)
|
var result sql.Result
|
||||||
|
run := func() error { var e error; result, e = p.getDB().ExecContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Exec failed: %v", err)
|
logger.Error("PgSQL Exec failed: %v", err)
|
||||||
return nil, err
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return nil, common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||||
return &PgSQLResult{result: result}, nil
|
return &PgSQLResult{result: result}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,23 +150,35 @@ func (p *PgSQLAdapter) Query(ctx context.Context, dest interface{}, query string
|
|||||||
err = logger.HandlePanic("PgSQLAdapter.Query", r)
|
err = logger.HandlePanic("PgSQLAdapter.Query", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||||
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Query: %s [args: %v]", query, args)
|
||||||
rows, err := p.db.QueryContext(ctx, query, args...)
|
var rows *sql.Rows
|
||||||
|
run := func() error { var e error; rows, e = p.getDB().QueryContext(ctx, query, args...); return e }
|
||||||
|
err = run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Query failed: %v", err)
|
logger.Error("PgSQL Query failed: %v", err)
|
||||||
return err
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return scanRows(rows, dest)
|
err = scanRows(rows, dest)
|
||||||
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (p *PgSQLAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
tx, err := p.db.BeginTx(ctx, nil)
|
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName}, nil
|
return &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
func (p *PgSQLAdapter) CommitTx(ctx context.Context) error {
|
||||||
@@ -123,12 +196,12 @@ func (p *PgSQLAdapter) RunInTransaction(ctx context.Context, fn func(common.Data
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tx, err := p.db.BeginTx(ctx, nil)
|
tx, err := p.getDB().BeginTx(ctx, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName}
|
adapter := &PgSQLTxAdapter{tx: tx, driverName: p.driverName, metricsEnabled: p.metricsEnabled}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if p := recover(); p != nil {
|
if p := recover(); p != nil {
|
||||||
@@ -174,8 +247,11 @@ type PgSQLSelectQuery struct {
|
|||||||
db *sql.DB
|
db *sql.DB
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
model interface{}
|
model interface{}
|
||||||
|
entity string
|
||||||
tableName string
|
tableName string
|
||||||
|
schema string
|
||||||
tableAlias string
|
tableAlias string
|
||||||
|
driverName string // Database driver name (postgres, sqlite, mssql)
|
||||||
columns []string
|
columns []string
|
||||||
columnExprs []string
|
columnExprs []string
|
||||||
whereClauses []string
|
whereClauses []string
|
||||||
@@ -189,13 +265,13 @@ type PgSQLSelectQuery struct {
|
|||||||
args []interface{}
|
args []interface{}
|
||||||
paramCounter int
|
paramCounter int
|
||||||
preloads []preloadConfig
|
preloads []preloadConfig
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
||||||
p.model = model
|
p.model = model
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||||
p.tableName = provider.TableName()
|
p.entity = entityNameFromModel(model, p.tableName)
|
||||||
}
|
|
||||||
if provider, ok := model.(common.TableAliasProvider); ok {
|
if provider, ok := model.(common.TableAliasProvider); ok {
|
||||||
p.tableAlias = provider.TableAlias()
|
p.tableAlias = provider.TableAlias()
|
||||||
}
|
}
|
||||||
@@ -203,7 +279,11 @@ func (p *PgSQLSelectQuery) Model(model interface{}) common.SelectQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
func (p *PgSQLSelectQuery) Table(table string) common.SelectQuery {
|
||||||
p.tableName = table
|
// For SQLite, convert "schema.table" to "schema_table"
|
||||||
|
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||||
|
if p.entity == "" {
|
||||||
|
p.entity = cleanMetricIdentifier(p.tableName)
|
||||||
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,12 +466,12 @@ func (p *PgSQLSelectQuery) buildSQL() string {
|
|||||||
|
|
||||||
// LIMIT clause
|
// LIMIT clause
|
||||||
if p.limit > 0 {
|
if p.limit > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" LIMIT %d", p.limit))
|
fmt.Fprintf(&sb, " LIMIT %d", p.limit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OFFSET clause
|
// OFFSET clause
|
||||||
if p.offset > 0 {
|
if p.offset > 0 {
|
||||||
sb.WriteString(fmt.Sprintf(" OFFSET %d", p.offset))
|
fmt.Fprintf(&sb, " OFFSET %d", p.offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
@@ -413,6 +493,7 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
|||||||
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
|
err = logger.HandlePanic("PgSQLSelectQuery.Scan", r)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
|
||||||
// Apply preloads that use JOINs
|
// Apply preloads that use JOINs
|
||||||
p.applyJoinPreloads()
|
p.applyJoinPreloads()
|
||||||
@@ -429,17 +510,21 @@ func (p *PgSQLSelectQuery) Scan(ctx context.Context, dest interface{}) (err erro
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL SELECT failed: %v", err)
|
logger.Error("PgSQL SELECT failed: %v", err)
|
||||||
return err
|
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
|
return common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
err = scanRows(rows, dest)
|
err = scanRows(rows, dest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply preloads that use separate queries
|
// Apply preloads that use separate queries
|
||||||
return p.applySubqueryPreloads(ctx, dest)
|
err = p.applySubqueryPreloads(ctx, dest)
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "SELECT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
||||||
@@ -449,15 +534,8 @@ func (p *PgSQLSelectQuery) ScanModel(ctx context.Context) error {
|
|||||||
return p.Scan(ctx, p.model)
|
return p.Scan(ctx, p.model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
// countInternal executes the COUNT query and returns the result and the SQL string without recording metrics.
|
||||||
defer func() {
|
func (p *PgSQLSelectQuery) countInternal(ctx context.Context) (rowCount int, querySQL string, retErr error) {
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
|
||||||
count = 0
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Build a COUNT query
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("SELECT COUNT(*) FROM ")
|
sb.WriteString("SELECT COUNT(*) FROM ")
|
||||||
sb.WriteString(p.tableName)
|
sb.WriteString(p.tableName)
|
||||||
@@ -491,10 +569,28 @@ func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
|||||||
row = p.db.QueryRowContext(ctx, query, p.args...)
|
row = p.db.QueryRowContext(ctx, query, p.args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = row.Scan(&count)
|
var count int
|
||||||
|
if err := row.Scan(&count); err != nil {
|
||||||
|
return 0, query, err
|
||||||
|
}
|
||||||
|
return count, query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLSelectQuery) Count(ctx context.Context) (count int, err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("PgSQLSelectQuery.Count", r)
|
||||||
|
count = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
|
var sqlStr string
|
||||||
|
count, sqlStr, err = p.countInternal(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL COUNT failed: %v", err)
|
logger.Error("PgSQL COUNT failed: %v", err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "COUNT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -505,8 +601,14 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
|||||||
exists = false
|
exists = false
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
startedAt := time.Now()
|
||||||
count, err := p.Count(ctx)
|
var sqlStr string
|
||||||
|
count, sqlStr, err := p.countInternal(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("PgSQL EXISTS failed: %v", err)
|
||||||
|
err = common.WrapSQLError(err, sqlStr)
|
||||||
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "EXISTS", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
return count > 0, err
|
return count > 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -514,26 +616,37 @@ func (p *PgSQLSelectQuery) Exists(ctx context.Context) (exists bool, err error)
|
|||||||
type PgSQLInsertQuery struct {
|
type PgSQLInsertQuery struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
|
schema string
|
||||||
tableName string
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
values map[string]interface{}
|
values map[string]interface{}
|
||||||
|
valueOrder []string
|
||||||
returning []string
|
returning []string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
func (p *PgSQLInsertQuery) Model(model interface{}) common.InsertQuery {
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||||
p.tableName = provider.TableName()
|
p.entity = entityNameFromModel(model, p.tableName)
|
||||||
}
|
|
||||||
// Extract values from model using reflection
|
// Extract values from model using reflection
|
||||||
// This is a simplified implementation
|
// This is a simplified implementation
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
func (p *PgSQLInsertQuery) Table(table string) common.InsertQuery {
|
||||||
p.tableName = table
|
// For SQLite, convert "schema.table" to "schema_table"
|
||||||
|
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||||
|
if p.entity == "" {
|
||||||
|
p.entity = cleanMetricIdentifier(p.tableName)
|
||||||
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
|
func (p *PgSQLInsertQuery) Value(column string, value interface{}) common.InsertQuery {
|
||||||
|
if _, exists := p.values[column]; !exists {
|
||||||
|
p.valueOrder = append(p.valueOrder, column)
|
||||||
|
}
|
||||||
p.values[column] = value
|
p.values[column] = value
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -549,25 +662,27 @@ func (p *PgSQLInsertQuery) Returning(columns ...string) common.InsertQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
|
err = logger.HandlePanic("PgSQLInsertQuery.Exec", r)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if len(p.values) == 0 {
|
if len(p.values) == 0 {
|
||||||
return nil, fmt.Errorf("no values to insert")
|
err = fmt.Errorf("no values to insert")
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
columns := make([]string, 0, len(p.values))
|
columns := make([]string, 0, len(p.values))
|
||||||
placeholders := make([]string, 0, len(p.values))
|
placeholders := make([]string, 0, len(p.values))
|
||||||
args := make([]interface{}, 0, len(p.values))
|
args := make([]interface{}, 0, len(p.values))
|
||||||
|
|
||||||
i := 1
|
i := 1
|
||||||
for col, val := range p.values {
|
for _, col := range p.valueOrder {
|
||||||
columns = append(columns, col)
|
columns = append(columns, col)
|
||||||
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
||||||
args = append(args, val)
|
args = append(args, p.values[col])
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -591,39 +706,96 @@ func (p *PgSQLInsertQuery) Exec(ctx context.Context) (res common.Result, err err
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL INSERT failed: %v", err)
|
logger.Error("PgSQL INSERT failed: %v", err)
|
||||||
return nil, err
|
return nil, common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PgSQLResult{result: result}, nil
|
return &PgSQLResult{result: result}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PgSQLInsertQuery) Scan(ctx context.Context, dest interface{}) (err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = logger.HandlePanic("PgSQLInsertQuery.Scan", r)
|
||||||
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "INSERT", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if len(p.values) == 0 {
|
||||||
|
return fmt.Errorf("no values to insert")
|
||||||
|
}
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(p.values))
|
||||||
|
placeholders := make([]string, 0, len(p.values))
|
||||||
|
args := make([]interface{}, 0, len(p.values))
|
||||||
|
i := 1
|
||||||
|
for _, col := range p.valueOrder {
|
||||||
|
columns = append(columns, col)
|
||||||
|
placeholders = append(placeholders, fmt.Sprintf("$%d", i))
|
||||||
|
args = append(args, p.values[col])
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||||
|
p.tableName,
|
||||||
|
strings.Join(columns, ", "),
|
||||||
|
strings.Join(placeholders, ", "))
|
||||||
|
|
||||||
|
if len(p.returning) > 0 {
|
||||||
|
query += " RETURNING " + strings.Join(p.returning, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("PgSQL INSERT (Scan): %s [args: %v]", query, args)
|
||||||
|
|
||||||
|
var row *sql.Row
|
||||||
|
if p.tx != nil {
|
||||||
|
row = p.tx.QueryRowContext(ctx, query, args...)
|
||||||
|
} else {
|
||||||
|
row = p.db.QueryRowContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := row.Scan(dest); err != nil {
|
||||||
|
return common.WrapSQLError(err, query)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
// PgSQLUpdateQuery implements UpdateQuery for PostgreSQL
|
||||||
type PgSQLUpdateQuery struct {
|
type PgSQLUpdateQuery struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
|
schema string
|
||||||
tableName string
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
model interface{}
|
model interface{}
|
||||||
sets map[string]interface{}
|
sets map[string]interface{}
|
||||||
|
setOrder []string
|
||||||
whereClauses []string
|
whereClauses []string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
paramCounter int
|
paramCounter int
|
||||||
returning []string
|
returning []string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
func (p *PgSQLUpdateQuery) Model(model interface{}) common.UpdateQuery {
|
||||||
p.model = model
|
p.model = model
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||||
p.tableName = provider.TableName()
|
p.entity = entityNameFromModel(model, p.tableName)
|
||||||
}
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
func (p *PgSQLUpdateQuery) Table(table string) common.UpdateQuery {
|
||||||
p.tableName = table
|
// For SQLite, convert "schema.table" to "schema_table"
|
||||||
|
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||||
|
if p.entity == "" {
|
||||||
|
p.entity = cleanMetricIdentifier(p.tableName)
|
||||||
|
}
|
||||||
if p.model == nil {
|
if p.model == nil {
|
||||||
model, err := modelregistry.GetModelByName(table)
|
model, err := modelregistry.GetModelByName(table)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
p.model = model
|
p.model = model
|
||||||
|
p.entity = entityNameFromModel(model, p.tableName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
@@ -633,6 +805,9 @@ func (p *PgSQLUpdateQuery) Set(column string, value interface{}) common.UpdateQu
|
|||||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
if _, exists := p.sets[column]; !exists {
|
||||||
|
p.setOrder = append(p.setOrder, column)
|
||||||
|
}
|
||||||
p.sets[column] = value
|
p.sets[column] = value
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
@@ -643,13 +818,23 @@ func (p *PgSQLUpdateQuery) SetMap(values map[string]interface{}) common.UpdateQu
|
|||||||
pkName = reflection.GetPrimaryKeyName(p.model)
|
pkName = reflection.GetPrimaryKeyName(p.model)
|
||||||
}
|
}
|
||||||
|
|
||||||
for column, value := range values {
|
orderedColumns := make([]string, 0, len(values))
|
||||||
|
for column := range values {
|
||||||
|
orderedColumns = append(orderedColumns, column)
|
||||||
|
}
|
||||||
|
sort.Strings(orderedColumns)
|
||||||
|
|
||||||
|
for _, column := range orderedColumns {
|
||||||
|
value := values[column]
|
||||||
if pkName != "" && column == pkName {
|
if pkName != "" && column == pkName {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
if p.model != nil && !reflection.IsColumnWritable(p.model, column) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if _, exists := p.sets[column]; !exists {
|
||||||
|
p.setOrder = append(p.setOrder, column)
|
||||||
|
}
|
||||||
p.sets[column] = value
|
p.sets[column] = value
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
@@ -678,24 +863,26 @@ func (p *PgSQLUpdateQuery) replacePlaceholders(query string, argCount int) strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
|
err = logger.HandlePanic("PgSQLUpdateQuery.Exec", r)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "UPDATE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if len(p.sets) == 0 {
|
if len(p.sets) == 0 {
|
||||||
return nil, fmt.Errorf("no values to update")
|
err = fmt.Errorf("no values to update")
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
setClauses := make([]string, 0, len(p.sets))
|
setClauses := make([]string, 0, len(p.sets))
|
||||||
setArgs := make([]interface{}, 0, len(p.sets))
|
setArgs := make([]interface{}, 0, len(p.sets))
|
||||||
|
|
||||||
// SET parameters start at $1
|
// SET parameters start at $1
|
||||||
i := 1
|
i := 1
|
||||||
for col, val := range p.sets {
|
for _, col := range p.setOrder {
|
||||||
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
|
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", col, i))
|
||||||
setArgs = append(setArgs, val)
|
setArgs = append(setArgs, p.sets[col])
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -749,7 +936,7 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL UPDATE failed: %v", err)
|
logger.Error("PgSQL UPDATE failed: %v", err)
|
||||||
return nil, err
|
return nil, common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PgSQLResult{result: result}, nil
|
return &PgSQLResult{result: result}, nil
|
||||||
@@ -759,21 +946,28 @@ func (p *PgSQLUpdateQuery) Exec(ctx context.Context) (res common.Result, err err
|
|||||||
type PgSQLDeleteQuery struct {
|
type PgSQLDeleteQuery struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
|
schema string
|
||||||
tableName string
|
tableName string
|
||||||
|
entity string
|
||||||
|
driverName string
|
||||||
whereClauses []string
|
whereClauses []string
|
||||||
args []interface{}
|
args []interface{}
|
||||||
paramCounter int
|
paramCounter int
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
func (p *PgSQLDeleteQuery) Model(model interface{}) common.DeleteQuery {
|
||||||
if provider, ok := model.(common.TableNameProvider); ok {
|
p.schema, p.tableName = schemaAndTableFromModel(model, p.driverName)
|
||||||
p.tableName = provider.TableName()
|
p.entity = entityNameFromModel(model, p.tableName)
|
||||||
}
|
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
func (p *PgSQLDeleteQuery) Table(table string) common.DeleteQuery {
|
||||||
p.tableName = table
|
// For SQLite, convert "schema.table" to "schema_table"
|
||||||
|
p.schema, p.tableName = parseTableName(table, p.driverName)
|
||||||
|
if p.entity == "" {
|
||||||
|
p.entity = cleanMetricIdentifier(p.tableName)
|
||||||
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -795,10 +989,12 @@ func (p *PgSQLDeleteQuery) replacePlaceholders(query string, argCount int) strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err error) {
|
||||||
|
startedAt := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
|
err = logger.HandlePanic("PgSQLDeleteQuery.Exec", r)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, "DELETE", p.schema, p.entity, p.tableName, startedAt, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
|
query := fmt.Sprintf("DELETE FROM %s", p.tableName)
|
||||||
@@ -818,7 +1014,7 @@ func (p *PgSQLDeleteQuery) Exec(ctx context.Context) (res common.Result, err err
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL DELETE failed: %v", err)
|
logger.Error("PgSQL DELETE failed: %v", err)
|
||||||
return nil, err
|
return nil, common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &PgSQLResult{result: result}, nil
|
return &PgSQLResult{result: result}, nil
|
||||||
@@ -848,60 +1044,78 @@ func (p *PgSQLResult) LastInsertId() (int64, error) {
|
|||||||
type PgSQLTxAdapter struct {
|
type PgSQLTxAdapter struct {
|
||||||
tx *sql.Tx
|
tx *sql.Tx
|
||||||
driverName string
|
driverName string
|
||||||
|
metricsEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
func (p *PgSQLTxAdapter) NewSelect() common.SelectQuery {
|
||||||
return &PgSQLSelectQuery{
|
return &PgSQLSelectQuery{
|
||||||
tx: p.tx,
|
tx: p.tx,
|
||||||
|
driverName: p.driverName,
|
||||||
columns: []string{"*"},
|
columns: []string{"*"},
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
func (p *PgSQLTxAdapter) NewInsert() common.InsertQuery {
|
||||||
return &PgSQLInsertQuery{
|
return &PgSQLInsertQuery{
|
||||||
tx: p.tx,
|
tx: p.tx,
|
||||||
|
driverName: p.driverName,
|
||||||
values: make(map[string]interface{}),
|
values: make(map[string]interface{}),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
func (p *PgSQLTxAdapter) NewUpdate() common.UpdateQuery {
|
||||||
return &PgSQLUpdateQuery{
|
return &PgSQLUpdateQuery{
|
||||||
tx: p.tx,
|
tx: p.tx,
|
||||||
|
driverName: p.driverName,
|
||||||
sets: make(map[string]interface{}),
|
sets: make(map[string]interface{}),
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
whereClauses: make([]string, 0),
|
whereClauses: make([]string, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
func (p *PgSQLTxAdapter) NewDelete() common.DeleteQuery {
|
||||||
return &PgSQLDeleteQuery{
|
return &PgSQLDeleteQuery{
|
||||||
tx: p.tx,
|
tx: p.tx,
|
||||||
|
driverName: p.driverName,
|
||||||
args: make([]interface{}, 0),
|
args: make([]interface{}, 0),
|
||||||
whereClauses: make([]string, 0),
|
whereClauses: make([]string, 0),
|
||||||
|
metricsEnabled: p.metricsEnabled,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
func (p *PgSQLTxAdapter) Exec(ctx context.Context, query string, args ...interface{}) (common.Result, error) {
|
||||||
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||||
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Tx Exec: %s [args: %v]", query, args)
|
||||||
result, err := p.tx.ExecContext(ctx, query, args...)
|
result, err := p.tx.ExecContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Tx Exec failed: %v", err)
|
logger.Error("PgSQL Tx Exec failed: %v", err)
|
||||||
return nil, err
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return nil, common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, nil)
|
||||||
return &PgSQLResult{result: result}, nil
|
return &PgSQLResult{result: result}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
func (p *PgSQLTxAdapter) Query(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||||
|
startedAt := time.Now()
|
||||||
|
operation, schema, entity, table := metricTargetFromRawQuery(query, p.driverName)
|
||||||
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
|
logger.Debug("PgSQL Tx Query: %s [args: %v]", query, args)
|
||||||
rows, err := p.tx.QueryContext(ctx, query, args...)
|
rows, err := p.tx.QueryContext(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("PgSQL Tx Query failed: %v", err)
|
logger.Error("PgSQL Tx Query failed: %v", err)
|
||||||
return err
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return common.WrapSQLError(err, query)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
return scanRows(rows, dest)
|
err = scanRows(rows, dest)
|
||||||
|
recordQueryMetrics(p.metricsEnabled, operation, schema, entity, table, startedAt, err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
func (p *PgSQLTxAdapter) BeginTx(ctx context.Context) (common.Database, error) {
|
||||||
@@ -1052,9 +1266,9 @@ func (p *PgSQLSelectQuery) executePreloadQuery(ctx context.Context, field reflec
|
|||||||
// Create a new select query for the related table
|
// Create a new select query for the related table
|
||||||
var db common.Database
|
var db common.Database
|
||||||
if p.tx != nil {
|
if p.tx != nil {
|
||||||
db = &PgSQLTxAdapter{tx: p.tx}
|
db = &PgSQLTxAdapter{tx: p.tx, driverName: p.driverName}
|
||||||
} else {
|
} else {
|
||||||
db = &PgSQLAdapter{db: p.db}
|
db = &PgSQLAdapter{db: p.db, driverName: p.driverName}
|
||||||
}
|
}
|
||||||
|
|
||||||
query := db.NewSelect().
|
query := db.NewSelect().
|
||||||
|
|||||||
@@ -0,0 +1,335 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxMetricFallbackEntityLength = 120
|
||||||
|
|
||||||
|
func recordQueryMetrics(enabled bool, operation, schema, entity, table string, startedAt time.Time, err error) {
|
||||||
|
if !enabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics.GetProvider().RecordDBQuery(
|
||||||
|
normalizeMetricOperation(operation),
|
||||||
|
normalizeMetricSchema(schema),
|
||||||
|
normalizeMetricEntity(entity, table),
|
||||||
|
normalizeMetricTable(table),
|
||||||
|
time.Since(startedAt),
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMetricOperation(operation string) string {
|
||||||
|
operation = strings.ToUpper(strings.TrimSpace(operation))
|
||||||
|
if operation == "" {
|
||||||
|
return "UNKNOWN"
|
||||||
|
}
|
||||||
|
return operation
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMetricSchema(schema string) string {
|
||||||
|
schema = cleanMetricIdentifier(schema)
|
||||||
|
if schema == "" {
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMetricEntity(entity, table string) string {
|
||||||
|
entity = cleanMetricIdentifier(entity)
|
||||||
|
if entity != "" {
|
||||||
|
return entity
|
||||||
|
}
|
||||||
|
|
||||||
|
table = cleanMetricIdentifier(table)
|
||||||
|
if table != "" {
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeMetricTable(table string) string {
|
||||||
|
table = cleanMetricIdentifier(table)
|
||||||
|
if table == "" {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return table
|
||||||
|
}
|
||||||
|
|
||||||
|
func entityNameFromModel(model interface{}, table string) string {
|
||||||
|
if model == nil {
|
||||||
|
return cleanMetricIdentifier(table)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil {
|
||||||
|
return cleanMetricIdentifier(table)
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType.Kind() == reflect.Struct && modelType.Name() != "" {
|
||||||
|
return reflection.ToSnakeCase(modelType.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
return cleanMetricIdentifier(table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func schemaAndTableFromModel(model interface{}, driverName string) (schema, table string) {
|
||||||
|
provider, ok := tableNameProviderFromModel(model)
|
||||||
|
if !ok {
|
||||||
|
return "", ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseTableName(provider.TableName(), driverName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tableNameProviderType is cached to avoid repeated reflection on every call.
|
||||||
|
var tableNameProviderType = reflect.TypeOf((*common.TableNameProvider)(nil)).Elem()
|
||||||
|
|
||||||
|
func tableNameProviderFromModel(model interface{}) (common.TableNameProvider, bool) {
|
||||||
|
if model == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
return provider, true
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check whether *T implements TableNameProvider before allocating.
|
||||||
|
ptrType := reflect.PointerTo(modelType)
|
||||||
|
if !ptrType.Implements(tableNameProviderType) && !modelType.Implements(tableNameProviderType) {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
modelValue := reflect.New(modelType)
|
||||||
|
if provider, ok := modelValue.Interface().(common.TableNameProvider); ok {
|
||||||
|
return provider, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider, ok := modelValue.Elem().Interface().(common.TableNameProvider); ok {
|
||||||
|
return provider, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func metricTargetFromRawQuery(query, driverName string) (operation, schema, entity, table string) {
|
||||||
|
operation = normalizeMetricOperation(firstQueryKeyword(query))
|
||||||
|
tableRef := tableFromRawQuery(query, operation)
|
||||||
|
if tableRef == "" {
|
||||||
|
return operation, "", fallbackMetricEntityFromQuery(query), "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
schema, table = parseTableName(tableRef, driverName)
|
||||||
|
entity = cleanMetricIdentifier(table)
|
||||||
|
return operation, schema, entity, table
|
||||||
|
}
|
||||||
|
|
||||||
|
func fallbackMetricEntityFromQuery(query string) string {
|
||||||
|
query = sanitizeMetricQueryShape(query)
|
||||||
|
if query == "" {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(query) > maxMetricFallbackEntityLength {
|
||||||
|
return query[:maxMetricFallbackEntityLength-3] + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeMetricQueryShape(query string) string {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if query == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var out strings.Builder
|
||||||
|
for i := 0; i < len(query); {
|
||||||
|
if query[i] == '\'' {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
for i < len(query) {
|
||||||
|
if query[i] == '\'' {
|
||||||
|
if i+1 < len(query) && query[i+1] == '\'' {
|
||||||
|
i += 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
break
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if query[i] == '?' {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if query[i] == '$' && i+1 < len(query) && isASCIIDigit(query[i+1]) {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
for i < len(query) && isASCIIDigit(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if query[i] == ':' && (i == 0 || query[i-1] != ':') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
for i < len(query) && isIdentifierPart(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if query[i] == '@' && (i == 0 || query[i-1] != '@') && i+1 < len(query) && isIdentifierStart(query[i+1]) {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
for i < len(query) && isIdentifierPart(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if startsNumericLiteral(query, i) {
|
||||||
|
out.WriteByte('?')
|
||||||
|
i++
|
||||||
|
for i < len(query) && (isASCIIDigit(query[i]) || query[i] == '.') {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
out.WriteByte(query[i])
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(strings.Fields(out.String()), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func startsNumericLiteral(query string, idx int) bool {
|
||||||
|
if idx >= len(query) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
start := idx
|
||||||
|
if query[idx] == '-' {
|
||||||
|
if idx+1 >= len(query) || !isASCIIDigit(query[idx+1]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isASCIIDigit(query[start]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if idx > 0 && isIdentifierPart(query[idx-1]) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if start+1 < len(query) && query[start] == '0' && (query[start+1] == 'x' || query[start+1] == 'X') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func isASCIIDigit(ch byte) bool {
|
||||||
|
return ch >= '0' && ch <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentifierStart(ch byte) bool {
|
||||||
|
return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || ch == '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentifierPart(ch byte) bool {
|
||||||
|
return isIdentifierStart(ch) || isASCIIDigit(ch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstQueryKeyword(query string) string {
|
||||||
|
query = strings.TrimSpace(query)
|
||||||
|
if query == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := strings.Fields(query)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fields[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func tableFromRawQuery(query, operation string) string {
|
||||||
|
tokens := tokenizeQuery(query)
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch operation {
|
||||||
|
case "SELECT":
|
||||||
|
return tokenAfter(tokens, "FROM")
|
||||||
|
case "INSERT":
|
||||||
|
return tokenAfter(tokens, "INTO")
|
||||||
|
case "UPDATE":
|
||||||
|
return tokenAfter(tokens, "UPDATE")
|
||||||
|
case "DELETE":
|
||||||
|
return tokenAfter(tokens, "FROM")
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenAfter(tokens []string, keyword string) string {
|
||||||
|
for idx, token := range tokens {
|
||||||
|
if strings.EqualFold(token, keyword) && idx+1 < len(tokens) {
|
||||||
|
return cleanMetricIdentifier(tokens[idx+1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func tokenizeQuery(query string) []string {
|
||||||
|
replacer := strings.NewReplacer(
|
||||||
|
"\n", " ",
|
||||||
|
"\t", " ",
|
||||||
|
"(", " ",
|
||||||
|
")", " ",
|
||||||
|
",", " ",
|
||||||
|
)
|
||||||
|
return strings.Fields(replacer.Replace(query))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanMetricIdentifier(value string) string {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
value = strings.Trim(value, "\"'`[]")
|
||||||
|
value = strings.TrimRight(value, ";")
|
||||||
|
return value
|
||||||
|
}
|
||||||
@@ -0,0 +1,348 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"github.com/uptrace/bun/dialect/sqlitedialect"
|
||||||
|
"github.com/uptrace/bun/driver/sqliteshim"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/metrics"
|
||||||
|
)
|
||||||
|
|
||||||
|
type queryMetricCall struct {
|
||||||
|
operation string
|
||||||
|
schema string
|
||||||
|
entity string
|
||||||
|
table string
|
||||||
|
}
|
||||||
|
|
||||||
|
type capturingMetricsProvider struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls []queryMetricCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *capturingMetricsProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {
|
||||||
|
}
|
||||||
|
func (c *capturingMetricsProvider) IncRequestsInFlight() {}
|
||||||
|
func (c *capturingMetricsProvider) DecRequestsInFlight() {}
|
||||||
|
func (c *capturingMetricsProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.calls = append(c.calls, queryMetricCall{
|
||||||
|
operation: operation,
|
||||||
|
schema: schema,
|
||||||
|
entity: entity,
|
||||||
|
table: table,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func (c *capturingMetricsProvider) RecordCacheHit(provider string) {}
|
||||||
|
func (c *capturingMetricsProvider) RecordCacheMiss(provider string) {}
|
||||||
|
func (c *capturingMetricsProvider) UpdateCacheSize(provider string, size int64) {
|
||||||
|
}
|
||||||
|
func (c *capturingMetricsProvider) RecordEventPublished(source, eventType string) {}
|
||||||
|
func (c *capturingMetricsProvider) RecordEventProcessed(source, eventType, status string, duration time.Duration) {
|
||||||
|
}
|
||||||
|
func (c *capturingMetricsProvider) UpdateEventQueueSize(size int64) {}
|
||||||
|
func (c *capturingMetricsProvider) RecordPanic(methodName string) {}
|
||||||
|
func (c *capturingMetricsProvider) Handler() http.Handler { return http.NewServeMux() }
|
||||||
|
|
||||||
|
func (c *capturingMetricsProvider) snapshot() []queryMetricCall {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
out := make([]queryMetricCall, len(c.calls))
|
||||||
|
copy(out, c.calls)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryMetricsGormUser struct {
|
||||||
|
ID int `gorm:"primaryKey"`
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (queryMetricsGormUser) TableName() string {
|
||||||
|
return "metrics_gorm_users"
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryMetricsBunUser struct {
|
||||||
|
bun.BaseModel `bun:"table:metrics_bun_users"`
|
||||||
|
ID int64 `bun:"id,pk,autoincrement"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRecordsSchemaEntityTableMetrics(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectExec(`UPDATE users SET name = \$1 WHERE id = \$2`).
|
||||||
|
WithArgs("Alice", 1).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
_, err = adapter.NewUpdate().
|
||||||
|
Table("public.users").
|
||||||
|
Set("name", "Alice").
|
||||||
|
Where("id = ?", 1).
|
||||||
|
Exec(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||||
|
assert.Equal(t, "public", calls[0].schema)
|
||||||
|
assert.Equal(t, "users", calls[0].entity)
|
||||||
|
assert.Equal(t, "users", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterDisableMetricsSuppressesEmission(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectExec(`DELETE FROM users WHERE id = \$1`).
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db).SetMetricsEnabled(false)
|
||||||
|
_, err = adapter.NewDelete().
|
||||||
|
Table("users").
|
||||||
|
Where("id = ?", 1).
|
||||||
|
Exec(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, mock.ExpectationsWereMet())
|
||||||
|
assert.Empty(t, provider.snapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGormAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||||
|
db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, db.AutoMigrate(&queryMetricsGormUser{}))
|
||||||
|
require.NoError(t, db.Create(&queryMetricsGormUser{Name: "Alice"}).Error)
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
adapter := NewGormAdapter(db)
|
||||||
|
var users []queryMetricsGormUser
|
||||||
|
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, users)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "SELECT", calls[0].operation)
|
||||||
|
assert.Equal(t, "default", calls[0].schema)
|
||||||
|
assert.Equal(t, "query_metrics_gorm_user", calls[0].entity)
|
||||||
|
assert.Equal(t, "metrics_gorm_users", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRecordsErrorMetric(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectExec(`INSERT INTO users`).
|
||||||
|
WillReturnError(fmt.Errorf("unique constraint violation"))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
_, err = adapter.NewInsert().
|
||||||
|
Table("users").
|
||||||
|
Value("name", "Alice").
|
||||||
|
Exec(context.Background())
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "INSERT", calls[0].operation)
|
||||||
|
assert.Equal(t, "users", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRecordsExistsMetric(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
exists, err := adapter.NewSelect().Table("users").Exists(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "EXISTS", calls[0].operation)
|
||||||
|
assert.Equal(t, "users", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRecordsCountMetric(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(5))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
count, err := adapter.NewSelect().Table("users").Count(context.Background())
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 5, count)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "COUNT", calls[0].operation)
|
||||||
|
assert.Equal(t, "users", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRawExecRecordsMetric(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
mock.ExpectExec(`UPDATE public\.orders SET status = \$1`).
|
||||||
|
WithArgs("shipped").
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 2))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
_, err = adapter.Exec(context.Background(), `UPDATE public.orders SET status = $1`, "shipped")
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "UPDATE", calls[0].operation)
|
||||||
|
assert.Equal(t, "public", calls[0].schema)
|
||||||
|
assert.Equal(t, "orders", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPgSQLAdapterRawExecUsesSQLAsEntityWhenTargetUnknown(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
query := `select core.c_setuserid($1)`
|
||||||
|
mock.ExpectExec(`select core\.c_setuserid\(\$1\)`).
|
||||||
|
WithArgs(42).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
adapter := NewPgSQLAdapter(db)
|
||||||
|
_, err = adapter.Exec(context.Background(), query, 42)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "SELECT", calls[0].operation)
|
||||||
|
assert.Equal(t, "default", calls[0].schema)
|
||||||
|
assert.Equal(t, "select core.c_setuserid(?)", calls[0].entity)
|
||||||
|
assert.Equal(t, "unknown", calls[0].table)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFallbackMetricEntityFromQuerySanitizesAndTruncates(t *testing.T) {
|
||||||
|
entity := fallbackMetricEntityFromQuery(" \n SELECT some_function(1, 'abc', $2, ?, :name, @p1, true, null) \t ")
|
||||||
|
assert.Equal(t, "SELECT some_function(?, ?, ?, ?, ?, ?, true, null)", entity)
|
||||||
|
|
||||||
|
entity = fallbackMetricEntityFromQuery("SELECT price::numeric, id FROM logs WHERE code = -42")
|
||||||
|
assert.Equal(t, "SELECT price::numeric, id FROM logs WHERE code = ?", entity)
|
||||||
|
|
||||||
|
longQuery := "SELECT " + strings.Repeat("x", maxMetricFallbackEntityLength)
|
||||||
|
entity = fallbackMetricEntityFromQuery(longQuery)
|
||||||
|
assert.Len(t, entity, maxMetricFallbackEntityLength)
|
||||||
|
assert.True(t, strings.HasSuffix(entity, "..."))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBunAdapterRecordsEntityAndTableMetrics(t *testing.T) {
|
||||||
|
sqldb, err := sql.Open(sqliteshim.ShimName, "file::memory:?cache=shared")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer sqldb.Close()
|
||||||
|
|
||||||
|
db := bun.NewDB(sqldb, sqlitedialect.New())
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err = db.NewCreateTable().
|
||||||
|
Model((*queryMetricsBunUser)(nil)).
|
||||||
|
IfNotExists().
|
||||||
|
Exec(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = db.NewInsert().Model(&queryMetricsBunUser{Name: "Alice"}).Exec(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
provider := &capturingMetricsProvider{}
|
||||||
|
prev := metrics.GetProvider()
|
||||||
|
metrics.SetProvider(provider)
|
||||||
|
defer metrics.SetProvider(prev)
|
||||||
|
|
||||||
|
adapter := NewBunAdapter(db)
|
||||||
|
var users []queryMetricsBunUser
|
||||||
|
err = adapter.NewSelect().Model(&users).Scan(context.Background(), &users)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, users)
|
||||||
|
|
||||||
|
calls := provider.snapshot()
|
||||||
|
require.Len(t, calls, 1)
|
||||||
|
assert.Equal(t, "SELECT", calls[0].operation)
|
||||||
|
assert.Equal(t, "default", calls[0].schema)
|
||||||
|
assert.Equal(t, "query_metrics_bun_user", calls[0].entity)
|
||||||
|
assert.Equal(t, "metrics_bun_users", calls[0].table)
|
||||||
|
}
|
||||||
@@ -62,9 +62,20 @@ func checkAliasLength(relation string) bool {
|
|||||||
// For example: "public.users" -> ("public", "users")
|
// For example: "public.users" -> ("public", "users")
|
||||||
//
|
//
|
||||||
// "users" -> ("", "users")
|
// "users" -> ("", "users")
|
||||||
func parseTableName(fullTableName string) (schema, table string) {
|
//
|
||||||
|
// For SQLite, schema.table is translated to schema_table since SQLite doesn't support schemas
|
||||||
|
// in the same way as PostgreSQL/MSSQL
|
||||||
|
func parseTableName(fullTableName, driverName string) (schema, table string) {
|
||||||
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
if idx := strings.LastIndex(fullTableName, "."); idx != -1 {
|
||||||
return fullTableName[:idx], fullTableName[idx+1:]
|
schema = fullTableName[:idx]
|
||||||
|
table = fullTableName[idx+1:]
|
||||||
|
|
||||||
|
// For SQLite, convert schema.table to schema_table
|
||||||
|
if driverName == "sqlite" || driverName == "sqlite3" {
|
||||||
|
table = schema + "_" + table
|
||||||
|
schema = ""
|
||||||
|
}
|
||||||
|
return schema, table
|
||||||
}
|
}
|
||||||
return "", fullTableName
|
return "", fullTableName
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -261,3 +262,48 @@ func GetTableNameFromModel(model interface{}) string {
|
|||||||
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
// This handles cases like "MasterTaskItem" -> "mastertaskitem"
|
||||||
return strings.ToLower(modelType.Name())
|
return strings.ToLower(modelType.Name())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConvertSliceForBun converts []interface{} values to PostgreSQL array literal strings.
|
||||||
|
// BUN's fallback appender for []interface{} is JSON encoding, which produces "[]" —
|
||||||
|
// invalid PostgreSQL array syntax. PostgreSQL expects "{}" for empty arrays and
|
||||||
|
// "{elem1,elem2}" for non-empty ones. All other value types are returned unchanged.
|
||||||
|
func ConvertSliceForBun(value interface{}) interface{} {
|
||||||
|
arr, ok := value.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
parts := make([]string, len(arr))
|
||||||
|
for i, elem := range arr {
|
||||||
|
switch e := elem.(type) {
|
||||||
|
case string:
|
||||||
|
needsQuote := e == "" || strings.ContainsAny(e, `,"\\{}`+"\t\n\r ")
|
||||||
|
if needsQuote {
|
||||||
|
e = strings.ReplaceAll(e, `\`, `\\`)
|
||||||
|
e = strings.ReplaceAll(e, `"`, `""`)
|
||||||
|
parts[i] = `"` + e + `"`
|
||||||
|
} else {
|
||||||
|
parts[i] = e
|
||||||
|
}
|
||||||
|
case float64:
|
||||||
|
if e == float64(int64(e)) {
|
||||||
|
parts[i] = strconv.FormatInt(int64(e), 10)
|
||||||
|
} else {
|
||||||
|
parts[i] = strconv.FormatFloat(e, 'f', -1, 64)
|
||||||
|
}
|
||||||
|
case bool:
|
||||||
|
if e {
|
||||||
|
parts[i] = "t"
|
||||||
|
} else {
|
||||||
|
parts[i] = "f"
|
||||||
|
}
|
||||||
|
case nil:
|
||||||
|
parts[i] = "NULL"
|
||||||
|
default:
|
||||||
|
parts[i] = fmt.Sprintf("%v", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "{" + strings.Join(parts, ",") + "}"
|
||||||
|
}
|
||||||
|
|||||||
@@ -106,3 +106,66 @@ func TestExtractTagValue(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertSliceForBun(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input interface{}
|
||||||
|
expected interface{}
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty slice produces empty pg array",
|
||||||
|
input: []interface{}{},
|
||||||
|
expected: "{}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string elements",
|
||||||
|
input: []interface{}{"a", "b", "c"},
|
||||||
|
expected: "{a,b,c}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string element needing quotes",
|
||||||
|
input: []interface{}{"hello world", "ok"},
|
||||||
|
expected: `{"hello world",ok}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with comma",
|
||||||
|
input: []interface{}{"a,b"},
|
||||||
|
expected: `{"a,b"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "integer elements (JSON float64)",
|
||||||
|
input: []interface{}{float64(1), float64(2), float64(3)},
|
||||||
|
expected: "{1,2,3}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bool elements",
|
||||||
|
input: []interface{}{true, false},
|
||||||
|
expected: "{t,f}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil input passthrough",
|
||||||
|
input: nil,
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string input passthrough",
|
||||||
|
input: "hello",
|
||||||
|
expected: "hello",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "int input passthrough",
|
||||||
|
input: 42,
|
||||||
|
expected: 42,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ConvertSliceForBun(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ConvertSliceForBun(%v) = %v; want %v", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ type InsertQuery interface {
|
|||||||
|
|
||||||
// Execution
|
// Execution
|
||||||
Exec(ctx context.Context) (Result, error)
|
Exec(ctx context.Context) (Result, error)
|
||||||
|
Scan(ctx context.Context, dest interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateQuery interface for building UPDATE queries
|
// UpdateQuery interface for building UPDATE queries
|
||||||
|
|||||||
+82
-115
@@ -98,8 +98,8 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Filter regularData to only include fields that exist in the model
|
// Filter regularData to only include fields that exist in the model,
|
||||||
// Use MapToStruct to validate and filter fields
|
// and translate JSON keys to their actual database column names.
|
||||||
regularData = p.filterValidFields(regularData, model)
|
regularData = p.filterValidFields(regularData, model)
|
||||||
|
|
||||||
// Inject parent IDs for foreign key resolution
|
// Inject parent IDs for foreign key resolution
|
||||||
@@ -125,6 +125,13 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.AffectedRows = 1
|
result.AffectedRows = 1
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
|
// Re-select the inserted row so result.Data reflects DB-generated defaults.
|
||||||
|
if row, err := p.processSelect(ctx, tableName, id); err != nil {
|
||||||
|
logger.Warn("Select after insert failed: table=%s, id=%v, error=%v", tableName, id, err)
|
||||||
|
} else if len(row) > 0 {
|
||||||
|
result.Data = row
|
||||||
|
}
|
||||||
|
|
||||||
// Process child relations after parent insert (to get parent ID)
|
// Process child relations after parent insert (to get parent ID)
|
||||||
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
if err := p.processChildRelations(ctx, "insert", id, relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
|
logger.Error("Failed to process child relations after insert: table=%s, parentID=%v, relations=%+v, error=%v", tableName, id, relationFields, err)
|
||||||
@@ -134,8 +141,12 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
logger.Debug("Skipping insert for %s - no data columns besides _request", tableName)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "update":
|
case "update", "change":
|
||||||
// Only perform update if we have data to update
|
// Only perform update if we have data to update
|
||||||
|
if reflection.IsEmptyValue(data[pkName]) {
|
||||||
|
logger.Warn("Skipping update for %s - no primary key", tableName)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
if hasData {
|
if hasData {
|
||||||
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
rows, err := p.processUpdate(ctx, regularData, tableName, data[pkName])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -146,9 +157,16 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
result.AffectedRows = rows
|
result.AffectedRows = rows
|
||||||
result.Data = regularData
|
result.Data = regularData
|
||||||
|
|
||||||
|
// Re-select the updated row so result.Data reflects current DB state.
|
||||||
|
if row, err := p.processSelect(ctx, tableName, result.ID); err != nil {
|
||||||
|
logger.Warn("Select after update failed: table=%s, id=%v, error=%v", tableName, result.ID, err)
|
||||||
|
} else if len(row) > 0 {
|
||||||
|
result.Data = row
|
||||||
|
}
|
||||||
|
|
||||||
// Process child relations for update
|
// Process child relations for update
|
||||||
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
if err := p.processChildRelations(ctx, "update", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
logger.Error("Failed to process child relations after update: table=%s, parentID=%v, relations=%+v, error=%v", tableName, data[pkName], regularData, err)
|
||||||
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -157,10 +175,15 @@ func (p *NestedCUDProcessor) ProcessNestedCUD(
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "delete":
|
case "delete":
|
||||||
|
if reflection.IsEmptyValue(data[pkName]) {
|
||||||
|
logger.Warn("Skipping delete for %s - no primary key", tableName)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Process child relations first (for referential integrity)
|
// Process child relations first (for referential integrity)
|
||||||
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
if err := p.processChildRelations(ctx, "delete", data[pkName], relationFields, result.RelationData, modelType, parentIDs); err != nil {
|
||||||
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
logger.Error("Failed to process child relations before delete: table=%s, id=%v, relations=%+v, error=%v", tableName, data[pkName], relationFields, err)
|
||||||
return nil, fmt.Errorf("failed to process child relations before delete: %w", err)
|
return nil, fmt.Errorf("failed to process child relations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
rows, err := p.processDelete(ctx, tableName, data[pkName])
|
||||||
@@ -191,14 +214,15 @@ func (p *NestedCUDProcessor) extractCRUDRequest(data map[string]interface{}) str
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterValidFields filters input data to only include fields that exist in the model
|
// filterValidFields filters input data to only include fields that exist in the model,
|
||||||
// Uses reflection.MapToStruct to validate fields and extract only those that match the model
|
// and translates JSON key names to their actual database column names.
|
||||||
|
// For example, a field tagged `json:"_changed_date" bun:"changed_date"` will be
|
||||||
|
// included in the result as "changed_date", not "_changed_date".
|
||||||
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, model interface{}) map[string]interface{} {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new instance of the model to use with MapToStruct
|
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array) {
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
@@ -208,25 +232,16 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
|||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new instance of the model
|
// Build a mapping from JSON key -> DB column name for all writable fields.
|
||||||
tempModel := reflect.New(modelType).Interface()
|
// This both validates which fields belong to the model and translates their names
|
||||||
|
// to the correct column names for use in SQL insert/update queries.
|
||||||
|
jsonToDBCol := reflection.BuildJSONToDBColumnMap(modelType)
|
||||||
|
|
||||||
// Use MapToStruct to map the data - this will only map valid fields
|
|
||||||
err := reflection.MapToStruct(data, tempModel)
|
|
||||||
if err != nil {
|
|
||||||
logger.Debug("Error mapping data to model: %v", err)
|
|
||||||
return data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the mapped fields back into a map
|
|
||||||
// This effectively filters out any fields that don't exist in the model
|
|
||||||
filteredData := make(map[string]interface{})
|
filteredData := make(map[string]interface{})
|
||||||
tempModelValue := reflect.ValueOf(tempModel).Elem()
|
|
||||||
|
|
||||||
for key, value := range data {
|
for key, value := range data {
|
||||||
// Check if the field was successfully mapped
|
dbColName, exists := jsonToDBCol[key]
|
||||||
if fieldWasMapped(tempModelValue, modelType, key) {
|
if exists {
|
||||||
filteredData[key] = value
|
filteredData[dbColName] = value
|
||||||
} else {
|
} else {
|
||||||
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
logger.Debug("Skipping invalid field '%s' - not found in model %v", key, modelType)
|
||||||
}
|
}
|
||||||
@@ -235,97 +250,44 @@ func (p *NestedCUDProcessor) filterValidFields(data map[string]interface{}, mode
|
|||||||
return filteredData
|
return filteredData
|
||||||
}
|
}
|
||||||
|
|
||||||
// fieldWasMapped checks if a field with the given key was mapped to the model
|
// injectForeignKeys injects parent IDs into data for foreign key fields.
|
||||||
func fieldWasMapped(modelValue reflect.Value, modelType reflect.Type, key string) bool {
|
// data is expected to be keyed by DB column names (as returned by filterValidFields).
|
||||||
// Look for the field by JSON tag or field name
|
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
|
||||||
field := modelType.Field(i)
|
|
||||||
|
|
||||||
// Skip unexported fields
|
|
||||||
if !field.IsExported() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check JSON tag
|
|
||||||
jsonTag := field.Tag.Get("json")
|
|
||||||
if jsonTag != "" && jsonTag != "-" {
|
|
||||||
parts := strings.Split(jsonTag, ",")
|
|
||||||
if len(parts) > 0 && parts[0] == key {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check bun tag
|
|
||||||
bunTag := field.Tag.Get("bun")
|
|
||||||
if bunTag != "" && bunTag != "-" {
|
|
||||||
if colName := reflection.ExtractColumnFromBunTag(bunTag); colName == key {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check gorm tag
|
|
||||||
gormTag := field.Tag.Get("gorm")
|
|
||||||
if gormTag != "" && gormTag != "-" {
|
|
||||||
if colName := reflection.ExtractColumnFromGormTag(gormTag); colName == key {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check lowercase field name
|
|
||||||
if strings.EqualFold(field.Name, key) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle embedded structs recursively
|
|
||||||
if field.Anonymous {
|
|
||||||
fieldType := field.Type
|
|
||||||
if fieldType.Kind() == reflect.Ptr {
|
|
||||||
fieldType = fieldType.Elem()
|
|
||||||
}
|
|
||||||
if fieldType.Kind() == reflect.Struct {
|
|
||||||
embeddedValue := modelValue.Field(i)
|
|
||||||
if embeddedValue.Kind() == reflect.Ptr {
|
|
||||||
if embeddedValue.IsNil() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
embeddedValue = embeddedValue.Elem()
|
|
||||||
}
|
|
||||||
if fieldWasMapped(embeddedValue, fieldType, key) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// injectForeignKeys injects parent IDs into data for foreign key fields
|
|
||||||
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
func (p *NestedCUDProcessor) injectForeignKeys(data map[string]interface{}, modelType reflect.Type, parentIDs map[string]interface{}) {
|
||||||
if len(parentIDs) == 0 {
|
if len(parentIDs) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iterate through model fields to find foreign key fields
|
pkCol := reflection.GetPrimaryKeyName(reflect.New(modelType).Interface())
|
||||||
|
|
||||||
|
for parentKey, parentID := range parentIDs {
|
||||||
|
dbColNames := reflection.GetForeignKeyColumn(modelType, parentKey)
|
||||||
|
|
||||||
|
if len(dbColNames) == 0 {
|
||||||
|
// No explicit tag found — fall back to naming convention by scanning scalar fields.
|
||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
jsonTag := field.Tag.Get("json")
|
jsonName := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||||
jsonName := strings.Split(jsonTag, ",")[0]
|
if strings.EqualFold(jsonName, "rid"+parentKey) ||
|
||||||
|
strings.EqualFold(jsonName, "rid_"+parentKey) ||
|
||||||
// Check if this field is a foreign key and we have a parent ID for it
|
strings.EqualFold(jsonName, "id_"+parentKey) ||
|
||||||
// Common patterns: DepartmentID, ManagerID, ProjectID, etc.
|
strings.EqualFold(jsonName, parentKey+"_id") ||
|
||||||
for parentKey, parentID := range parentIDs {
|
|
||||||
// Match field name patterns like "department_id" with parent key "department"
|
|
||||||
if strings.EqualFold(jsonName, parentKey+"_id") ||
|
|
||||||
strings.EqualFold(jsonName, parentKey+"id") ||
|
strings.EqualFold(jsonName, parentKey+"id") ||
|
||||||
strings.EqualFold(field.Name, parentKey+"ID") {
|
strings.EqualFold(field.Name, parentKey+"ID") {
|
||||||
// Only inject if not already present
|
dbColNames = []string{reflection.GetColumnName(field)}
|
||||||
if _, exists := data[jsonName]; !exists {
|
break
|
||||||
logger.Debug("Injecting foreign key: %s = %v", jsonName, parentID)
|
|
||||||
data[jsonName] = parentID
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, dbColName := range dbColNames {
|
||||||
|
if pkCol != "" && strings.EqualFold(dbColName, pkCol) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := data[dbColName]; !exists {
|
||||||
|
logger.Debug("Injecting foreign key: %s = %v", dbColName, parentID)
|
||||||
|
data[dbColName] = parentID
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,28 +302,33 @@ func (p *NestedCUDProcessor) processInsert(
|
|||||||
query := p.db.NewInsert().Table(tableName)
|
query := p.db.NewInsert().Table(tableName)
|
||||||
|
|
||||||
for key, value := range data {
|
for key, value := range data {
|
||||||
query = query.Value(key, value)
|
query = query.Value(key, ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
pkName := reflection.GetPrimaryKeyName(tableName)
|
pkName := reflection.GetPrimaryKeyName(tableName)
|
||||||
// Add RETURNING clause to get the inserted ID
|
|
||||||
query = query.Returning(pkName)
|
query = query.Returning(pkName)
|
||||||
|
|
||||||
result, err := query.Exec(ctx)
|
var id interface{}
|
||||||
if err != nil {
|
if err := query.Scan(ctx, &id); err != nil {
|
||||||
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
|
logger.Error("Insert execution failed: table=%s, data=%+v, error=%v", tableName, data, err)
|
||||||
return nil, fmt.Errorf("insert exec failed: %w", err)
|
return nil, fmt.Errorf("insert exec failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to get the ID
|
logger.Debug("Insert successful, ID: %v", id)
|
||||||
var id interface{}
|
return id, nil
|
||||||
if lastID, err := result.LastInsertId(); err == nil && lastID > 0 {
|
|
||||||
id = lastID
|
|
||||||
} else if data[pkName] != nil {
|
|
||||||
id = data[pkName]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Insert successful, ID: %v, rows affected: %d", id, result.RowsAffected())
|
// processSelect fetches the row identified by id from tableName into a flat map.
|
||||||
return id, nil
|
// Used to populate result.Data with the actual DB state after insert/update.
|
||||||
|
func (p *NestedCUDProcessor) processSelect(ctx context.Context, tableName string, id interface{}) (map[string]interface{}, error) {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(tableName)
|
||||||
|
var row map[string]interface{}
|
||||||
|
if err := p.db.NewSelect().
|
||||||
|
Table(tableName).
|
||||||
|
Where(fmt.Sprintf("%s = ?", QuoteIdent(pkName)), id).
|
||||||
|
Scan(ctx, &row); err != nil {
|
||||||
|
return nil, fmt.Errorf("select after write failed: %w", err)
|
||||||
|
}
|
||||||
|
return row, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// processUpdate handles update operation
|
// processUpdate handles update operation
|
||||||
|
|||||||
@@ -101,12 +101,18 @@ func (m *mockInsertQuery) Value(column string, value interface{}) InsertQuery {
|
|||||||
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
|
func (m *mockInsertQuery) OnConflict(action string) InsertQuery { return m }
|
||||||
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
|
func (m *mockInsertQuery) Returning(columns ...string) InsertQuery { return m }
|
||||||
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
|
func (m *mockInsertQuery) Exec(ctx context.Context) (Result, error) {
|
||||||
// Record the insert call
|
|
||||||
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
||||||
m.db.lastID++
|
m.db.lastID++
|
||||||
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
|
return &mockResult{lastID: m.db.lastID, rowsAffected: 1}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockInsertQuery) Scan(ctx context.Context, dest interface{}) error {
|
||||||
|
m.db.insertCalls = append(m.db.insertCalls, m.values)
|
||||||
|
m.db.lastID++
|
||||||
|
reflect.ValueOf(dest).Elem().Set(reflect.ValueOf(m.db.lastID))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Mock UpdateQuery
|
// Mock UpdateQuery
|
||||||
type mockUpdateQuery struct {
|
type mockUpdateQuery struct {
|
||||||
db *mockDatabase
|
db *mockDatabase
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package common
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -58,6 +59,38 @@ func IsSQLExpression(cond string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reEmptyCompMid matches a simple column comparison with an empty RHS that is immediately
|
||||||
|
// followed by AND/OR (only whitespace between the operator and the next keyword).
|
||||||
|
// Removing the match leaves the preceding AND/OR connector intact.
|
||||||
|
// Example: "cond1 and col = \n and cond2" → "cond1 and cond2"
|
||||||
|
var reEmptyCompMid = regexp.MustCompile(`(?i)[\w.]+\s*(?:=|<>|!=|>=|<=|>|<)\s+(?:and|or)\s+`)
|
||||||
|
|
||||||
|
// reEmptyCompEnd matches AND/OR + a simple column comparison with an empty RHS at the end
|
||||||
|
// of the string (or sub-clause).
|
||||||
|
// Example: "cond1 and col = " → "cond1"
|
||||||
|
var reEmptyCompEnd = regexp.MustCompile(`(?i)\s+(?:and|or)\s+[\w.]+\s*(?:=|<>|!=|>=|<=|>|<)\s*$`)
|
||||||
|
|
||||||
|
// stripEmptyComparisonClauses removes comparison conditions that have no right-hand side
|
||||||
|
// value from a raw SQL string. Operates on the whole string so it also cleans up conditions
|
||||||
|
// inside subqueries, not just top-level AND splits.
|
||||||
|
func stripEmptyComparisonClauses(sql string) string {
|
||||||
|
sql = reEmptyCompMid.ReplaceAllLiteralString(sql, "")
|
||||||
|
sql = reEmptyCompEnd.ReplaceAllLiteralString(sql, "")
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasEmptyRHS returns true when a condition ends with a comparison operator and has no
|
||||||
|
// right-hand side value — e.g., "col = ", "com.rid_parent = ", "col >= ".
|
||||||
|
func hasEmptyRHS(cond string) bool {
|
||||||
|
cond = strings.TrimSpace(cond)
|
||||||
|
for _, op := range []string{"<>", "!=", ">=", "<=", "=", ">", "<"} {
|
||||||
|
if strings.HasSuffix(cond, op) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
// IsTrivialCondition checks if a condition is trivial and always evaluates to true
|
||||||
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
// These conditions should be removed from WHERE clauses as they have no filtering effect
|
||||||
func IsTrivialCondition(cond string) bool {
|
func IsTrivialCondition(cond string) bool {
|
||||||
@@ -146,6 +179,14 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Strip comparison conditions with empty RHS throughout the SQL string (including
|
||||||
|
// inside subqueries), before condition splitting.
|
||||||
|
where = stripEmptyComparisonClauses(where)
|
||||||
|
if where == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
where = strings.TrimSpace(where)
|
||||||
|
|
||||||
// Check if the original clause has outer parentheses and contains OR operators
|
// Check if the original clause has outer parentheses and contains OR operators
|
||||||
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
|
// If so, we need to preserve the outer parentheses to prevent OR logic from escaping
|
||||||
hasOuterParens := false
|
hasOuterParens := false
|
||||||
@@ -167,16 +208,17 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build a set of allowed table prefixes (main table + preloaded relations)
|
// Build a set of allowed table prefixes (main table + preloaded relations)
|
||||||
|
// Keys are stored lowercase for case-insensitive matching
|
||||||
allowedPrefixes := make(map[string]bool)
|
allowedPrefixes := make(map[string]bool)
|
||||||
if tableName != "" {
|
if tableName != "" {
|
||||||
allowedPrefixes[tableName] = true
|
allowedPrefixes[strings.ToLower(tableName)] = true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add preload relation names as allowed prefixes
|
// Add preload relation names as allowed prefixes
|
||||||
if len(options) > 0 && options[0] != nil {
|
if len(options) > 0 && options[0] != nil {
|
||||||
for pi := range options[0].Preload {
|
for pi := range options[0].Preload {
|
||||||
if options[0].Preload[pi].Relation != "" {
|
if options[0].Preload[pi].Relation != "" {
|
||||||
allowedPrefixes[options[0].Preload[pi].Relation] = true
|
allowedPrefixes[strings.ToLower(options[0].Preload[pi].Relation)] = true
|
||||||
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
logger.Debug("Added preload relation '%s' as allowed table prefix", options[0].Preload[pi].Relation)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,7 +226,7 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
// Add join aliases as allowed prefixes
|
// Add join aliases as allowed prefixes
|
||||||
for _, alias := range options[0].JoinAliases {
|
for _, alias := range options[0].JoinAliases {
|
||||||
if alias != "" {
|
if alias != "" {
|
||||||
allowedPrefixes[alias] = true
|
allowedPrefixes[strings.ToLower(alias)] = true
|
||||||
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
logger.Debug("Added join alias '%s' as allowed table prefix", alias)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -210,14 +252,20 @@ func SanitizeWhereClause(where string, tableName string, options ...*RequestOpti
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip conditions with no right-hand side value (e.g. "col = " with empty value)
|
||||||
|
if hasEmptyRHS(condToCheck) {
|
||||||
|
logger.Debug("Removing condition with empty value: '%s'", cond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
// If tableName is provided and the condition HAS a table prefix, check if it's correct
|
||||||
if tableName != "" && hasTablePrefix(condToCheck) {
|
if tableName != "" && hasTablePrefix(condToCheck) {
|
||||||
// Extract the current prefix and column name
|
// Extract the current prefix and column name
|
||||||
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
currentPrefix, columnName := extractTableAndColumn(condToCheck)
|
||||||
|
|
||||||
if currentPrefix != "" && columnName != "" {
|
if currentPrefix != "" && columnName != "" {
|
||||||
// Check if the prefix is allowed (main table or preload relation)
|
// Check if the prefix is allowed (main table or preload relation) - case-insensitive
|
||||||
if !allowedPrefixes[currentPrefix] {
|
if !allowedPrefixes[strings.ToLower(currentPrefix)] {
|
||||||
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
// Prefix is not in the allowed list - only fix if it's a valid column in the main table
|
||||||
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
if validColumns == nil || isValidColumn(columnName, validColumns) {
|
||||||
// Replace the incorrect prefix with the correct main table name
|
// Replace the incorrect prefix with the correct main table name
|
||||||
@@ -925,3 +973,36 @@ func extractLeftSideOfComparison(cond string) string {
|
|||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FilterValueToSlice converts a filter value to []interface{} for use with IN operators.
|
||||||
|
// JSON-decoded arrays arrive as []interface{}, but typed slices (e.g. []string) also work.
|
||||||
|
// Returns a single-element slice if the value is not a slice type.
|
||||||
|
func FilterValueToSlice(v interface{}) []interface{} {
|
||||||
|
if v == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() == reflect.Slice {
|
||||||
|
result := make([]interface{}, rv.Len())
|
||||||
|
for i := 0; i < rv.Len(); i++ {
|
||||||
|
result[i] = rv.Index(i).Interface()
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return []interface{}{v}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildInCondition builds a parameterized IN condition from a filter value.
|
||||||
|
// Returns the condition string (e.g. "col IN (?,?)") and the individual values as args.
|
||||||
|
// Returns ("", nil) if the value is empty or not a slice.
|
||||||
|
func BuildInCondition(column string, v interface{}) (query string, args []interface{}) {
|
||||||
|
values := FilterValueToSlice(v)
|
||||||
|
if len(values) == 0 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
placeholders := make([]string, len(values))
|
||||||
|
for i := range values {
|
||||||
|
placeholders[i] = "?"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")), values
|
||||||
|
}
|
||||||
|
|||||||
@@ -134,6 +134,30 @@ func TestSanitizeWhereClause(t *testing.T) {
|
|||||||
tableName: "apiprovider",
|
tableName: "apiprovider",
|
||||||
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
expected: "apiprovider.type in ('softphone') AND (apiprovider.rid_apiprovider in (select l.rid_apiprovider from core.apiproviderlink l where l.rid_hub = 2576))",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "empty RHS stripped mid-clause",
|
||||||
|
where: "com.tableprefix = 'tcli' and com.rid_parent = \n and com.status = 'Active'",
|
||||||
|
tableName: "",
|
||||||
|
expected: "com.tableprefix = 'tcli' AND com.status = 'Active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty RHS stripped at end of clause",
|
||||||
|
where: "com.tableprefix = 'tcli' and com.rid_parent =",
|
||||||
|
tableName: "",
|
||||||
|
expected: "com.tableprefix = 'tcli'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-empty value not stripped",
|
||||||
|
where: "com.tableprefix = 'tcli' and com.rid_parent = 123 and com.status = 'Active'",
|
||||||
|
tableName: "",
|
||||||
|
expected: "com.tableprefix = 'tcli' AND com.rid_parent = 123 AND com.status = 'Active'",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty RHS inside subquery stripped",
|
||||||
|
where: "a = 1 and b in (select x from t where c.rid = \n and d = 2)",
|
||||||
|
tableName: "",
|
||||||
|
expected: "a = 1 AND b in (select x from t where d = 2)",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -1,5 +1,23 @@
|
|||||||
package common
|
package common
|
||||||
|
|
||||||
|
// SQLError wraps a database error together with the SQL that caused it,
|
||||||
|
// so callers can surface the query in API error responses for easier debugging.
|
||||||
|
type SQLError struct {
|
||||||
|
Err error
|
||||||
|
SQL string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *SQLError) Error() string { return e.Err.Error() }
|
||||||
|
func (e *SQLError) Unwrap() error { return e.Err }
|
||||||
|
|
||||||
|
// WrapSQLError wraps err with the given SQL. If err is nil it returns nil.
|
||||||
|
func WrapSQLError(err error, sql string) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &SQLError{Err: err, SQL: sql}
|
||||||
|
}
|
||||||
|
|
||||||
type RequestBody struct {
|
type RequestBody struct {
|
||||||
Operation string `json:"operation"`
|
Operation string `json:"operation"`
|
||||||
Data interface{} `json:"data"`
|
Data interface{} `json:"data"`
|
||||||
@@ -104,6 +122,7 @@ type APIError struct {
|
|||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Details interface{} `json:"details,omitempty"`
|
Details interface{} `json:"details,omitempty"`
|
||||||
Detail string `json:"detail,omitempty"`
|
Detail string `json:"detail,omitempty"`
|
||||||
|
SQL string `json:"sql,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Column struct {
|
type Column struct {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package common
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -43,7 +44,7 @@ func (v *ColumnValidator) buildValidColumns() {
|
|||||||
for i := 0; i < modelType.NumField(); i++ {
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
field := modelType.Field(i)
|
field := modelType.Field(i)
|
||||||
|
|
||||||
if !field.IsExported() {
|
if !field.IsExported() || field.Anonymous {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,6 +126,16 @@ func (v *ColumnValidator) IsValidColumn(column string) bool {
|
|||||||
return v.ValidateColumn(column) == nil
|
return v.ValidateColumn(column) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Columns returns all valid column names known to this validator
|
||||||
|
func (v *ColumnValidator) Columns() []string {
|
||||||
|
cols := make([]string, 0, len(v.validColumns))
|
||||||
|
for col := range v.validColumns {
|
||||||
|
cols = append(cols, col)
|
||||||
|
}
|
||||||
|
sort.Strings(cols)
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
// FilterValidColumns filters a list of columns, returning only valid ones
|
// FilterValidColumns filters a list of columns, returning only valid ones
|
||||||
// Logs warnings for any invalid columns
|
// Logs warnings for any invalid columns
|
||||||
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
func (v *ColumnValidator) FilterValidColumns(columns []string) []string {
|
||||||
@@ -224,7 +235,19 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter Filter columns
|
// Filter Filter columns
|
||||||
validFilters := make([]FilterOption, 0, len(options.Filters))
|
validFilters := make([]FilterOption, 0, len(options.Filters))
|
||||||
for _, filter := range options.Filters {
|
for _, filter := range options.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if strings.EqualFold(filter.Column, "all") {
|
||||||
|
allCols := v.Columns()
|
||||||
|
if len(filtered.Columns) > 0 {
|
||||||
|
allCols = filtered.Columns
|
||||||
|
}
|
||||||
|
for _, col := range allCols {
|
||||||
|
expanded := filter
|
||||||
|
expanded.Column = col
|
||||||
|
expanded.LogicOperator = "OR"
|
||||||
|
|
||||||
|
validFilters = append(validFilters, expanded)
|
||||||
|
}
|
||||||
|
} else if v.IsValidColumn(filter.Column) {
|
||||||
validFilters = append(validFilters, filter)
|
validFilters = append(validFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
logger.Warn("Invalid column in filter '%s' removed", filter.Column)
|
||||||
@@ -266,11 +289,24 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
|
|
||||||
// Filter Preload columns
|
// Filter Preload columns
|
||||||
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
validPreloads := make([]PreloadOption, 0, len(options.Preload))
|
||||||
|
modelType := reflect.TypeOf(v.model)
|
||||||
|
if modelType != nil && modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
for idx := range options.Preload {
|
for idx := range options.Preload {
|
||||||
preload := options.Preload[idx]
|
preload := options.Preload[idx]
|
||||||
filteredPreload := preload
|
filteredPreload := preload
|
||||||
filteredPreload.Columns = v.FilterValidColumns(preload.Columns)
|
|
||||||
filteredPreload.OmitColumns = v.FilterValidColumns(preload.OmitColumns)
|
// Use the related model's validator for preload columns/filters/sorts
|
||||||
|
preloadValidator := v
|
||||||
|
if modelType != nil {
|
||||||
|
if relInfo := GetRelationshipInfo(modelType, preload.Relation); relInfo != nil && relInfo.RelatedModel != nil {
|
||||||
|
preloadValidator = NewColumnValidator(relInfo.RelatedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredPreload.Columns = preloadValidator.FilterValidColumns(preload.Columns)
|
||||||
|
filteredPreload.OmitColumns = preloadValidator.FilterValidColumns(preload.OmitColumns)
|
||||||
|
|
||||||
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
// Preserve SqlJoins and JoinAliases for preloads with custom joins
|
||||||
filteredPreload.SqlJoins = preload.SqlJoins
|
filteredPreload.SqlJoins = preload.SqlJoins
|
||||||
@@ -279,7 +315,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter preload filters
|
// Filter preload filters
|
||||||
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
validPreloadFilters := make([]FilterOption, 0, len(preload.Filters))
|
||||||
for _, filter := range preload.Filters {
|
for _, filter := range preload.Filters {
|
||||||
if v.IsValidColumn(filter.Column) {
|
if preloadValidator.IsValidColumn(filter.Column) {
|
||||||
validPreloadFilters = append(validPreloadFilters, filter)
|
validPreloadFilters = append(validPreloadFilters, filter)
|
||||||
} else {
|
} else {
|
||||||
// Check if the filter column references a joined table alias
|
// Check if the filter column references a joined table alias
|
||||||
@@ -302,7 +338,7 @@ func (v *ColumnValidator) FilterRequestOptions(options RequestOptions) RequestOp
|
|||||||
// Filter preload sort columns
|
// Filter preload sort columns
|
||||||
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
validPreloadSorts := make([]SortOption, 0, len(preload.Sort))
|
||||||
for _, sort := range preload.Sort {
|
for _, sort := range preload.Sort {
|
||||||
if v.IsValidColumn(sort.Column) {
|
if preloadValidator.IsValidColumn(sort.Column) {
|
||||||
validPreloadSorts = append(validPreloadSorts, sort)
|
validPreloadSorts = append(validPreloadSorts, sort)
|
||||||
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
} else if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
// Allow sort by expression/subquery, but validate for security
|
// Allow sort by expression/subquery, but validate for security
|
||||||
|
|||||||
@@ -464,3 +464,84 @@ func TestFilterRequestOptions_WithSortExpressions(t *testing.T) {
|
|||||||
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
t.Errorf("Expected third sort to be 'name', got '%s'", filtered.Sort[2].Column)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RelatedModel is used by PreloadParentModel to test preload column validation.
|
||||||
|
type RelatedModel struct {
|
||||||
|
RelatedID int64 `bun:"related_id,pk"`
|
||||||
|
Functionname string `bun:"functionname"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PreloadParentModel has a has-one relation to RelatedModel. The json tag on
|
||||||
|
// the relation field is the name used in x-preload headers.
|
||||||
|
type PreloadParentModel struct {
|
||||||
|
ID int64 `bun:"id,pk"`
|
||||||
|
Name string `bun:"name"`
|
||||||
|
RELATED *RelatedModel `json:"RELATED" bun:"rel:has-one,join:id=related_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel verifies
|
||||||
|
// that preload columns are validated against the related model's fields, not the
|
||||||
|
// parent model's fields. This is the fix for the bug where specifying a column
|
||||||
|
// that exists only on the relation (e.g. "functionname") was incorrectly filtered
|
||||||
|
// out because it doesn't exist on the parent model.
|
||||||
|
func TestFilterRequestOptions_PreloadColumnsValidatedAgainstRelatedModel(t *testing.T) {
|
||||||
|
validator := NewColumnValidator(PreloadParentModel{})
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "RELATED",
|
||||||
|
// "functionname" exists on RelatedModel but NOT on PreloadParentModel.
|
||||||
|
// "name" exists on PreloadParentModel but NOT on RelatedModel.
|
||||||
|
// "nonexistent" exists on neither.
|
||||||
|
Columns: []string{"functionname", "name", "nonexistent"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
if len(filtered.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
cols := filtered.Preload[0].Columns
|
||||||
|
// Only "functionname" should survive: it belongs to RelatedModel.
|
||||||
|
if len(cols) != 1 {
|
||||||
|
t.Errorf("Expected 1 preload column, got %d: %v", len(cols), cols)
|
||||||
|
}
|
||||||
|
if len(cols) > 0 && cols[0] != "functionname" {
|
||||||
|
t.Errorf("Expected preload column 'functionname', got '%s'", cols[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFilterRequestOptions_PreloadColumnsParentModelFallback verifies that when
|
||||||
|
// a preload relation is not found on the parent model, column validation falls
|
||||||
|
// back to the parent model's validator (no panic, no silent pass-through).
|
||||||
|
func TestFilterRequestOptions_PreloadColumnsParentModelFallback(t *testing.T) {
|
||||||
|
validator := NewColumnValidator(PreloadParentModel{})
|
||||||
|
|
||||||
|
options := RequestOptions{
|
||||||
|
Preload: []PreloadOption{
|
||||||
|
{
|
||||||
|
Relation: "UNKNOWN_RELATION",
|
||||||
|
Columns: []string{"id", "functionname"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
if len(filtered.Preload) != 1 {
|
||||||
|
t.Fatalf("Expected 1 preload, got %d", len(filtered.Preload))
|
||||||
|
}
|
||||||
|
|
||||||
|
cols := filtered.Preload[0].Columns
|
||||||
|
// Falls back to parent model: only "id" is valid on PreloadParentModel.
|
||||||
|
if len(cols) != 1 {
|
||||||
|
t.Errorf("Expected 1 preload column (fallback to parent), got %d: %v", len(cols), cols)
|
||||||
|
}
|
||||||
|
if len(cols) > 0 && cols[0] != "id" {
|
||||||
|
t.Errorf("Expected preload column 'id', got '%s'", cols[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ A comprehensive database connection manager for Go that provides centralized man
|
|||||||
- **GORM** - Popular Go ORM
|
- **GORM** - Popular Go ORM
|
||||||
- **Native** - Standard library `*sql.DB`
|
- **Native** - Standard library `*sql.DB`
|
||||||
- All three share the same underlying connection pool
|
- All three share the same underlying connection pool
|
||||||
|
- **SQLite Schema Translation**: Automatic conversion of `schema.table` to `schema_table` for SQLite compatibility
|
||||||
- **Configuration-Driven**: YAML configuration with Viper integration
|
- **Configuration-Driven**: YAML configuration with Viper integration
|
||||||
- **Production-Ready Features**:
|
- **Production-Ready Features**:
|
||||||
- Automatic health checks and reconnection
|
- Automatic health checks and reconnection
|
||||||
@@ -179,6 +180,35 @@ if err != nil {
|
|||||||
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
rows, err := nativeDB.QueryContext(ctx, "SELECT * FROM users WHERE active = $1", true)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Cross-Database Example with SQLite
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Same model works across all databases
|
||||||
|
type User struct {
|
||||||
|
ID int `bun:"id,pk"`
|
||||||
|
Username string `bun:"username"`
|
||||||
|
Email string `bun:"email"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "auth.users"
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostgreSQL connection
|
||||||
|
pgConn, _ := mgr.Get("primary")
|
||||||
|
pgDB, _ := pgConn.Bun()
|
||||||
|
var pgUsers []User
|
||||||
|
pgDB.NewSelect().Model(&pgUsers).Scan(ctx)
|
||||||
|
// Executes: SELECT * FROM auth.users
|
||||||
|
|
||||||
|
// SQLite connection
|
||||||
|
sqliteConn, _ := mgr.Get("cache-db")
|
||||||
|
sqliteDB, _ := sqliteConn.Bun()
|
||||||
|
var sqliteUsers []User
|
||||||
|
sqliteDB.NewSelect().Model(&sqliteUsers).Scan(ctx)
|
||||||
|
// Executes: SELECT * FROM auth_users (schema.table → schema_table)
|
||||||
|
```
|
||||||
|
|
||||||
#### Use MongoDB
|
#### Use MongoDB
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -368,6 +398,37 @@ Providers handle:
|
|||||||
- Connection statistics
|
- Connection statistics
|
||||||
- Connection cleanup
|
- Connection cleanup
|
||||||
|
|
||||||
|
### SQLite Schema Handling
|
||||||
|
|
||||||
|
SQLite doesn't support schemas in the same way as PostgreSQL or MSSQL. To ensure compatibility when using models designed for multi-schema databases:
|
||||||
|
|
||||||
|
**Automatic Translation**: When a table name contains a schema prefix (e.g., `myschema.mytable`), it is automatically converted to `myschema_mytable` for SQLite databases.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Model definition (works across all databases)
|
||||||
|
func (User) TableName() string {
|
||||||
|
return "auth.users" // PostgreSQL/MSSQL: "auth"."users"
|
||||||
|
// SQLite: "auth_users"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query execution
|
||||||
|
db.NewSelect().Model(&User{}).Scan(ctx)
|
||||||
|
// PostgreSQL/MSSQL: SELECT * FROM auth.users
|
||||||
|
// SQLite: SELECT * FROM auth_users
|
||||||
|
```
|
||||||
|
|
||||||
|
**How it Works**:
|
||||||
|
- Bun, GORM, and Native adapters detect the driver type
|
||||||
|
- `parseTableName()` automatically translates schema.table → schema_table for SQLite
|
||||||
|
- Translation happens transparently in all database operations (SELECT, INSERT, UPDATE, DELETE)
|
||||||
|
- Preload and relation queries are also handled automatically
|
||||||
|
|
||||||
|
**Benefits**:
|
||||||
|
- Write database-agnostic code
|
||||||
|
- Use the same models across PostgreSQL, MSSQL, and SQLite
|
||||||
|
- No conditional logic needed in your application
|
||||||
|
- Schema separation maintained through naming convention in SQLite
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
||||||
1. **Use Named Connections**: Be explicit about which database you're accessing
|
1. **Use Named Connections**: Be explicit about which database you're accessing
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ type Connection interface {
|
|||||||
Bun() (*bun.DB, error)
|
Bun() (*bun.DB, error)
|
||||||
GORM() (*gorm.DB, error)
|
GORM() (*gorm.DB, error)
|
||||||
Native() (*sql.DB, error)
|
Native() (*sql.DB, error)
|
||||||
|
DB() (*sql.DB, error)
|
||||||
|
|
||||||
// Common Database interface (for SQL databases)
|
// Common Database interface (for SQL databases)
|
||||||
Database() (common.Database, error)
|
Database() (common.Database, error)
|
||||||
@@ -224,6 +225,11 @@ func (c *sqlConnection) Native() (*sql.DB, error) {
|
|||||||
return c.nativeDB, nil
|
return c.nativeDB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB returns the underlying *sql.DB connection
|
||||||
|
func (c *sqlConnection) DB() (*sql.DB, error) {
|
||||||
|
return c.Native()
|
||||||
|
}
|
||||||
|
|
||||||
// Bun returns a Bun ORM instance wrapping the native connection
|
// Bun returns a Bun ORM instance wrapping the native connection
|
||||||
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
func (c *sqlConnection) Bun() (*bun.DB, error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -353,6 +359,42 @@ func (c *sqlConnection) Stats() *ConnectionStats {
|
|||||||
return stats
|
return stats
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reconnectForAdapter() error {
|
||||||
|
timeout := c.config.ConnectTimeout
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
return c.Reconnect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenNativeForAdapter() (*sql.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Native()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenBunForAdapter() (*bun.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Bun()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sqlConnection) reopenGORMForAdapter() (*gorm.DB, error) {
|
||||||
|
if err := c.reconnectForAdapter(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.GORM()
|
||||||
|
}
|
||||||
|
|
||||||
// getBunAdapter returns or creates the Bun adapter
|
// getBunAdapter returns or creates the Bun adapter
|
||||||
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -385,7 +427,9 @@ func (c *sqlConnection) getBunAdapter() (common.Database, error) {
|
|||||||
c.bunDB = bun.NewDB(native, dialect)
|
c.bunDB = bun.NewDB(native, dialect)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.bunAdapter = database.NewBunAdapter(c.bunDB)
|
c.bunAdapter = database.NewBunAdapter(c.bunDB).
|
||||||
|
WithDBFactory(c.reopenBunForAdapter).
|
||||||
|
SetMetricsEnabled(c.config.EnableMetrics)
|
||||||
return c.bunAdapter, nil
|
return c.bunAdapter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -426,7 +470,9 @@ func (c *sqlConnection) getGORMAdapter() (common.Database, error) {
|
|||||||
c.gormDB = db
|
c.gormDB = db
|
||||||
}
|
}
|
||||||
|
|
||||||
c.gormAdapter = database.NewGormAdapter(c.gormDB)
|
c.gormAdapter = database.NewGormAdapter(c.gormDB).
|
||||||
|
WithDBFactory(c.reopenGORMForAdapter).
|
||||||
|
SetMetricsEnabled(c.config.EnableMetrics)
|
||||||
return c.gormAdapter, nil
|
return c.gormAdapter, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,11 +513,17 @@ func (c *sqlConnection) getNativeAdapter() (common.Database, error) {
|
|||||||
// Create a native adapter based on database type
|
// Create a native adapter based on database type
|
||||||
switch c.dbType {
|
switch c.dbType {
|
||||||
case DatabaseTypePostgreSQL:
|
case DatabaseTypePostgreSQL:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||||
|
WithDBFactory(c.reopenNativeForAdapter).
|
||||||
|
SetMetricsEnabled(c.config.EnableMetrics)
|
||||||
case DatabaseTypeSQLite:
|
case DatabaseTypeSQLite:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||||
|
WithDBFactory(c.reopenNativeForAdapter).
|
||||||
|
SetMetricsEnabled(c.config.EnableMetrics)
|
||||||
case DatabaseTypeMSSQL:
|
case DatabaseTypeMSSQL:
|
||||||
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType))
|
c.nativeAdapter = database.NewPgSQLAdapter(c.nativeDB, string(c.dbType)).
|
||||||
|
WithDBFactory(c.reopenNativeForAdapter).
|
||||||
|
SetMetricsEnabled(c.config.EnableMetrics)
|
||||||
default:
|
default:
|
||||||
return nil, ErrUnsupportedDatabase
|
return nil, ErrUnsupportedDatabase
|
||||||
}
|
}
|
||||||
@@ -645,6 +697,11 @@ func (c *mongoConnection) Native() (*sql.DB, error) {
|
|||||||
return nil, ErrNotSQLDatabase
|
return nil, ErrNotSQLDatabase
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DB returns an error for MongoDB connections
|
||||||
|
func (c *mongoConnection) DB() (*sql.DB, error) {
|
||||||
|
return nil, ErrNotSQLDatabase
|
||||||
|
}
|
||||||
|
|
||||||
// Database returns an error for MongoDB connections
|
// Database returns an error for MongoDB connections
|
||||||
func (c *mongoConnection) Database() (common.Database, error) {
|
func (c *mongoConnection) Database() (common.Database, error) {
|
||||||
return nil, ErrNotSQLDatabase
|
return nil, ErrNotSQLDatabase
|
||||||
|
|||||||
@@ -4,8 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewConnectionFromDB(t *testing.T) {
|
func TestNewConnectionFromDB(t *testing.T) {
|
||||||
@@ -208,3 +213,157 @@ func TestNewConnectionFromDB_PostgreSQL(t *testing.T) {
|
|||||||
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
t.Errorf("Expected type DatabaseTypePostgreSQL, got '%s'", conn.Type())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDatabaseNativeAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-native", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-native",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeNative),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.PgSQLAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected PgSQLAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingBefore, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := underlyingBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||||
|
t.Fatalf("Expected native adapter to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingAfter, ok := adapter.GetUnderlyingDB().(*sql.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected reconnected *sql.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if underlyingAfter == underlyingBefore {
|
||||||
|
t.Fatal("Expected adapter to swap to a fresh *sql.DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseBunAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-bun", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-bun",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeBun),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.BunAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected BunAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingBefore, ok := adapter.GetUnderlyingDB().(interface{ Close() error })
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying Bun DB with Close method, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := underlyingBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying Bun database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(ctx, "SELECT 1"); err != nil {
|
||||||
|
t.Fatalf("Expected Bun adapter to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlyingAfter := adapter.GetUnderlyingDB()
|
||||||
|
if underlyingAfter == underlyingBefore {
|
||||||
|
t.Fatal("Expected adapter to swap to a fresh Bun DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDatabaseGormAdapterReconnectFactory(t *testing.T) {
|
||||||
|
conn := newSQLConnection("test-gorm", DatabaseTypeSQLite, ConnectionConfig{
|
||||||
|
Name: "test-gorm",
|
||||||
|
Type: DatabaseTypeSQLite,
|
||||||
|
FilePath: ":memory:",
|
||||||
|
DefaultORM: string(ORMTypeGORM),
|
||||||
|
ConnectTimeout: 2 * time.Second,
|
||||||
|
}, providers.NewSQLiteProvider())
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := conn.Connect(ctx); err != nil {
|
||||||
|
t.Fatalf("Failed to connect: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
db, err := conn.Database()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get database adapter: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
adapter, ok := db.(*database.GormAdapter)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected GormAdapter, got %T", db)
|
||||||
|
}
|
||||||
|
|
||||||
|
gormBefore, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected underlying *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlBefore, err := gormBefore.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get underlying *sql.DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sqlBefore.Close(); err != nil {
|
||||||
|
t.Fatalf("Failed to close underlying database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count, err := db.NewSelect().Table("sqlite_master").Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected GORM query builder to reconnect, got error: %v", err)
|
||||||
|
}
|
||||||
|
if count < 0 {
|
||||||
|
t.Fatalf("Expected non-negative count, got %d", count)
|
||||||
|
}
|
||||||
|
|
||||||
|
gormAfter, ok := adapter.GetUnderlyingDB().(*gorm.DB)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected reconnected *gorm.DB, got %T", adapter.GetUnderlyingDB())
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlAfter, err := gormAfter.DB()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get reconnected *sql.DB: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sqlAfter == sqlBefore {
|
||||||
|
t.Fatal("Expected GORM adapter to use a fresh *sql.DB after reconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package dbmanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -366,8 +368,11 @@ func (m *connectionManager) performHealthCheck() {
|
|||||||
"connection", item.name,
|
"connection", item.name,
|
||||||
"error", err)
|
"error", err)
|
||||||
|
|
||||||
// Attempt reconnection if enabled
|
// Only reconnect when the client handle itself is closed/disconnected.
|
||||||
if m.config.EnableAutoReconnect {
|
// For transient database restarts or network blips, *sql.DB can recover
|
||||||
|
// on its own; forcing Close()+Connect() here invalidates any cached ORM
|
||||||
|
// wrappers and callers that still hold the old handle.
|
||||||
|
if m.config.EnableAutoReconnect && shouldReconnectAfterHealthCheck(err) {
|
||||||
logger.Info("Attempting reconnection: connection=%s", item.name)
|
logger.Info("Attempting reconnection: connection=%s", item.name)
|
||||||
if err := item.conn.Reconnect(ctx); err != nil {
|
if err := item.conn.Reconnect(ctx); err != nil {
|
||||||
logger.Error("Reconnection failed",
|
logger.Error("Reconnection failed",
|
||||||
@@ -376,7 +381,21 @@ func (m *connectionManager) performHealthCheck() {
|
|||||||
} else {
|
} else {
|
||||||
logger.Info("Reconnection successful: connection=%s", item.name)
|
logger.Info("Reconnection successful: connection=%s", item.name)
|
||||||
}
|
}
|
||||||
|
} else if m.config.EnableAutoReconnect {
|
||||||
|
logger.Info("Skipping reconnect for transient health check failure: connection=%s", item.name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldReconnectAfterHealthCheck(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, ErrConnectionClosed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,38 @@ package dbmanager
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type healthCheckStubConnection struct {
|
||||||
|
healthErr error
|
||||||
|
reconnectCalls int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *healthCheckStubConnection) Name() string { return "stub" }
|
||||||
|
func (c *healthCheckStubConnection) Type() DatabaseType { return DatabaseTypePostgreSQL }
|
||||||
|
func (c *healthCheckStubConnection) Bun() (*bun.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) GORM() (*gorm.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Native() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) DB() (*sql.DB, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Database() (common.Database, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) MongoDB() (*mongo.Client, error) { return nil, fmt.Errorf("not implemented") }
|
||||||
|
func (c *healthCheckStubConnection) Connect(ctx context.Context) error { return nil }
|
||||||
|
func (c *healthCheckStubConnection) Close() error { return nil }
|
||||||
|
func (c *healthCheckStubConnection) HealthCheck(ctx context.Context) error { return c.healthErr }
|
||||||
|
func (c *healthCheckStubConnection) Reconnect(ctx context.Context) error { c.reconnectCalls++; return nil }
|
||||||
|
func (c *healthCheckStubConnection) Stats() *ConnectionStats { return &ConnectionStats{} }
|
||||||
|
|
||||||
func TestBackgroundHealthChecker(t *testing.T) {
|
func TestBackgroundHealthChecker(t *testing.T) {
|
||||||
// Create a SQLite in-memory database
|
// Create a SQLite in-memory database
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
@@ -224,3 +250,41 @@ func TestManagerStatsAfterClose(t *testing.T) {
|
|||||||
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
t.Errorf("Expected 0 total connections after close, got %d", stats.TotalConnections)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPerformHealthCheckSkipsReconnectForTransientFailures(t *testing.T) {
|
||||||
|
conn := &healthCheckStubConnection{
|
||||||
|
healthErr: fmt.Errorf("connection 'primary' health check: dial tcp 127.0.0.1:5432: connect: connection refused"),
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &connectionManager{
|
||||||
|
connections: map[string]Connection{"primary": conn},
|
||||||
|
config: ManagerConfig{
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.performHealthCheck()
|
||||||
|
|
||||||
|
if conn.reconnectCalls != 0 {
|
||||||
|
t.Fatalf("expected no reconnect attempts for transient health failure, got %d", conn.reconnectCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPerformHealthCheckReconnectsClosedConnections(t *testing.T) {
|
||||||
|
conn := &healthCheckStubConnection{
|
||||||
|
healthErr: NewConnectionError("primary", "health check", fmt.Errorf("sql: database is closed")),
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := &connectionManager{
|
||||||
|
connections: map[string]Connection{"primary": conn},
|
||||||
|
config: ManagerConfig{
|
||||||
|
EnableAutoReconnect: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.performHealthCheck()
|
||||||
|
|
||||||
|
if conn.reconnectCalls != 1 {
|
||||||
|
t.Fatalf("expected reconnect attempt for closed database handle, got %d", conn.reconnectCalls)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package providers_test
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
"github.com/bitechdev/ResolveSpec/pkg/dbmanager"
|
||||||
@@ -29,14 +30,14 @@ func ExamplePostgresListener_basic() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
if err := provider.Connect(ctx, cfg); err != nil {
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
log.Fatalf("Failed to connect: %v", err)
|
||||||
}
|
}
|
||||||
defer provider.Close()
|
defer provider.Close()
|
||||||
|
|
||||||
// Get listener
|
// Get listener
|
||||||
listener, err := provider.GetListener(ctx)
|
listener, err := provider.GetListener(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
log.Fatalf("Failed to get listener: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe to a channel with a handler
|
// Subscribe to a channel with a handler
|
||||||
@@ -44,13 +45,13 @@ func ExamplePostgresListener_basic() {
|
|||||||
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
fmt.Printf("Received notification on %s: %s\n", channel, payload)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to listen: %v", err))
|
log.Fatalf("Failed to listen: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send a notification
|
// Send a notification
|
||||||
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
err = listener.Notify(ctx, "user_events", `{"event":"user_created","user_id":123}`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to notify: %v", err))
|
log.Fatalf("Failed to notify: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for notification to be processed
|
// Wait for notification to be processed
|
||||||
@@ -58,7 +59,7 @@ func ExamplePostgresListener_basic() {
|
|||||||
|
|
||||||
// Unsubscribe from the channel
|
// Unsubscribe from the channel
|
||||||
if err := listener.Unlisten("user_events"); err != nil {
|
if err := listener.Unlisten("user_events"); err != nil {
|
||||||
panic(fmt.Sprintf("Failed to unlisten: %v", err))
|
log.Fatalf("Failed to unlisten: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,13 +81,13 @@ func ExamplePostgresListener_multipleChannels() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
if err := provider.Connect(ctx, cfg); err != nil {
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
log.Fatalf("Failed to connect: %v", err)
|
||||||
}
|
}
|
||||||
defer provider.Close()
|
defer provider.Close()
|
||||||
|
|
||||||
listener, err := provider.GetListener(ctx)
|
listener, err := provider.GetListener(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
log.Fatalf("Failed to get listener: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Listen to multiple channels
|
// Listen to multiple channels
|
||||||
@@ -97,7 +98,7 @@ func ExamplePostgresListener_multipleChannels() {
|
|||||||
fmt.Printf("[%s] %s\n", ch, payload)
|
fmt.Printf("[%s] %s\n", ch, payload)
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to listen on %s: %v", channel, err))
|
log.Fatalf("Failed to listen on %s: %v", channel, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,14 +141,14 @@ func ExamplePostgresListener_withDBManager() {
|
|||||||
|
|
||||||
provider := providers.NewPostgresProvider()
|
provider := providers.NewPostgresProvider()
|
||||||
if err := provider.Connect(ctx, cfg); err != nil {
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
panic(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
defer provider.Close()
|
defer provider.Close()
|
||||||
|
|
||||||
// Get listener
|
// Get listener
|
||||||
listener, err := provider.GetListener(ctx)
|
listener, err := provider.GetListener(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Subscribe to application events
|
// Subscribe to application events
|
||||||
@@ -186,13 +187,13 @@ func ExamplePostgresListener_errorHandling() {
|
|||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
if err := provider.Connect(ctx, cfg); err != nil {
|
if err := provider.Connect(ctx, cfg); err != nil {
|
||||||
panic(fmt.Sprintf("Failed to connect: %v", err))
|
log.Fatalf("Failed to connect: %v", err)
|
||||||
}
|
}
|
||||||
defer provider.Close()
|
defer provider.Close()
|
||||||
|
|
||||||
listener, err := provider.GetListener(ctx)
|
listener, err := provider.GetListener(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("Failed to get listener: %v", err))
|
log.Fatalf("Failed to get listener: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// The listener automatically reconnects if the connection is lost
|
// The listener automatically reconnects if the connection is lost
|
||||||
|
|||||||
@@ -4,11 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/mongo"
|
"go.mongodb.org/mongo-driver/mongo"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isDBClosed reports whether err indicates the *sql.DB has been closed.
|
||||||
|
func isDBClosed(err error) bool {
|
||||||
|
return err != nil && strings.Contains(err.Error(), "sql: database is closed")
|
||||||
|
}
|
||||||
|
|
||||||
// Common errors
|
// Common errors
|
||||||
var (
|
var (
|
||||||
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
// ErrNotSQLDatabase is returned when attempting SQL operations on a non-SQL database
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
_ "github.com/glebarez/sqlite" // Pure Go SQLite driver
|
||||||
@@ -15,6 +16,8 @@ import (
|
|||||||
// SQLiteProvider implements Provider for SQLite databases
|
// SQLiteProvider implements Provider for SQLite databases
|
||||||
type SQLiteProvider struct {
|
type SQLiteProvider struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
config ConnectionConfig
|
config ConnectionConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +132,13 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
|||||||
|
|
||||||
// Execute a simple query to verify the database is accessible
|
// Execute a simple query to verify the database is accessible
|
||||||
var result int
|
var result int
|
||||||
err := p.db.QueryRowContext(healthCtx, "SELECT 1").Scan(&result)
|
run := func() error { return p.getDB().QueryRowContext(healthCtx, "SELECT 1").Scan(&result) }
|
||||||
|
err := run()
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := p.reconnectDB(); reconnErr == nil {
|
||||||
|
err = run()
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("health check failed: %w", err)
|
return fmt.Errorf("health check failed: %w", err)
|
||||||
}
|
}
|
||||||
@@ -141,6 +150,32 @@ func (p *SQLiteProvider) HealthCheck(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithDBFactory configures a factory used to reopen the database connection if it is closed.
|
||||||
|
func (p *SQLiteProvider) WithDBFactory(factory func() (*sql.DB, error)) *SQLiteProvider {
|
||||||
|
p.dbFactory = factory
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLiteProvider) getDB() *sql.DB {
|
||||||
|
p.dbMu.RLock()
|
||||||
|
defer p.dbMu.RUnlock()
|
||||||
|
return p.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SQLiteProvider) reconnectDB() error {
|
||||||
|
if p.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := p.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.dbMu.Lock()
|
||||||
|
p.db = newDB
|
||||||
|
p.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetNative returns the native *sql.DB connection
|
// GetNative returns the native *sql.DB connection
|
||||||
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
func (p *SQLiteProvider) GetNative() (*sql.DB, error) {
|
||||||
if p.db == nil {
|
if p.db == nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (s *SentryProvider) CaptureError(ctx context.Context, err error, severity S
|
|||||||
}
|
}
|
||||||
|
|
||||||
if extra != nil {
|
if extra != nil {
|
||||||
event.Extra = extra
|
event.Contexts["extra"] = sentry.Context(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
hub.CaptureEvent(event)
|
hub.CaptureEvent(event)
|
||||||
@@ -88,7 +88,7 @@ func (s *SentryProvider) CaptureMessage(ctx context.Context, message string, sev
|
|||||||
event.Message = message
|
event.Message = message
|
||||||
|
|
||||||
if extra != nil {
|
if extra != nil {
|
||||||
event.Extra = extra
|
event.Contexts["extra"] = sentry.Context(extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
hub.CaptureEvent(event)
|
hub.CaptureEvent(event)
|
||||||
@@ -115,12 +115,15 @@ func (s *SentryProvider) CapturePanic(ctx context.Context, recovered interface{}
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if extra != nil {
|
extraCtx := sentry.Context{}
|
||||||
event.Extra = extra
|
for k, v := range extra {
|
||||||
|
extraCtx[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
if stackTrace != nil {
|
if stackTrace != nil {
|
||||||
event.Extra["stack_trace"] = string(stackTrace)
|
extraCtx["stack_trace"] = string(stackTrace)
|
||||||
|
}
|
||||||
|
if len(extraCtx) > 0 {
|
||||||
|
event.Contexts["extra"] = extraCtx
|
||||||
}
|
}
|
||||||
|
|
||||||
hub.CaptureEvent(event)
|
hub.CaptureEvent(event)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package funcspec
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -168,9 +169,16 @@ func (h *Handler) SqlQueryList(sqlquery string, options SqlQueryOptions) HTTPFun
|
|||||||
// Replace meta variables in SQL
|
// Replace meta variables in SQL
|
||||||
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
sqlquery = h.replaceMetaVariables(sqlquery, r, userCtx, metainfo, variables)
|
||||||
|
|
||||||
// Remove unused input variables
|
// Replace variables from provided values, then blank any remaining unused ones
|
||||||
if options.BlankParams {
|
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
|
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||||
|
if val, ok := variables[varName]; ok {
|
||||||
|
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.BlankParams {
|
||||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
||||||
@@ -520,9 +528,16 @@ func (h *Handler) SqlQuery(sqlquery string, options SqlQueryOptions) HTTPFuncTyp
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove unused input variables
|
// Replace variables from provided values, then blank any remaining unused ones
|
||||||
if options.BlankParams {
|
|
||||||
for _, kw := range inputvars {
|
for _, kw := range inputvars {
|
||||||
|
varName := kw[1 : len(kw)-1] // strip [ and ]
|
||||||
|
if val, ok := variables[varName]; ok {
|
||||||
|
if strVal := fmt.Sprintf("%v", val); strVal != "" {
|
||||||
|
sqlquery = strings.ReplaceAll(sqlquery, kw, safeSubstituteVar(sqlquery, kw, strVal))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if options.BlankParams {
|
||||||
replacement := getReplacementForBlankParam(sqlquery, kw)
|
replacement := getReplacementForBlankParam(sqlquery, kw)
|
||||||
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
sqlquery = strings.ReplaceAll(sqlquery, kw, replacement)
|
||||||
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
logger.Debug("Replaced unused variable %s with: %s", kw, replacement)
|
||||||
@@ -715,8 +730,10 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
|||||||
propQry[parmk] = val
|
propQry[parmk] = val
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply filters if allowed
|
// Apply filters if allowed — check only the SELECT list to avoid matching function
|
||||||
if allowFilter && len(parmk) > 1 && strings.Contains(strings.ToLower(sqlquery), strings.ToLower(parmk)) {
|
// parameters in the FROM clause (e.g. [p_rid_doctype] in a set-returning function call)
|
||||||
|
// or names inside quoted string arguments.
|
||||||
|
if allowFilter && len(parmk) > 1 && strings.Contains(sqlSelectList(sqlStripStringLiterals(sqlquery)), strings.ToLower(parmk)) {
|
||||||
if len(parmv) > 1 {
|
if len(parmv) > 1 {
|
||||||
// Sanitize each value in the IN clause with appropriate quoting
|
// Sanitize each value in the IN clause with appropriate quoting
|
||||||
sanitizedValues := make([]string, len(parmv))
|
sanitizedValues := make([]string, len(parmv))
|
||||||
@@ -739,7 +756,7 @@ func (h *Handler) mergeQueryParams(r *http.Request, sqlquery string, variables m
|
|||||||
colval = strings.ReplaceAll(colval, "\\", "\\\\")
|
colval = strings.ReplaceAll(colval, "\\", "\\\\")
|
||||||
colval = strings.ReplaceAll(colval, "'", "''")
|
colval = strings.ReplaceAll(colval, "'", "''")
|
||||||
if colval != "*" {
|
if colval != "*" {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(parmk, "colname"), colval))
|
||||||
}
|
}
|
||||||
} else if val == "" || val == "0" {
|
} else if val == "" || val == "0" {
|
||||||
// For empty/zero values, treat as literal 0 or empty string with quotes
|
// For empty/zero values, treat as literal 0 or empty string with quotes
|
||||||
@@ -806,7 +823,7 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
|||||||
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
|
colname := strings.ReplaceAll(k, "x-searchfilter-", "")
|
||||||
sval := strings.ReplaceAll(val, "'", "")
|
sval := strings.ReplaceAll(val, "'", "")
|
||||||
if sval != "" {
|
if sval != "" {
|
||||||
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
sqlquery = sqlQryWhere(sqlquery, fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(colname, "colname"), ValidSQL(sval, "colvalue")))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,6 +841,26 @@ func (h *Handler) mergeHeaderParams(r *http.Request, sqlquery string, variables
|
|||||||
return sqlquery
|
return sqlquery
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sqlStripStringLiterals removes the contents of single-quoted string literals from SQL,
|
||||||
|
// leaving the structural identifiers (column names, table names) intact.
|
||||||
|
// Used to check column presence without matching inside string arguments.
|
||||||
|
func sqlStripStringLiterals(sql string) string {
|
||||||
|
re := regexp.MustCompile(`'(?:[^']|'')*'`)
|
||||||
|
return re.ReplaceAllString(sql, "''")
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqlSelectList returns the column list portion of a SELECT query (between SELECT and FROM).
|
||||||
|
// Returns the full query lowercased if no clear SELECT…FROM boundary is found.
|
||||||
|
func sqlSelectList(sql string) string {
|
||||||
|
lower := strings.ToLower(sql)
|
||||||
|
selectPos := strings.Index(lower, "select ")
|
||||||
|
fromPos := strings.Index(lower, " from ")
|
||||||
|
if selectPos < 0 || fromPos <= selectPos {
|
||||||
|
return lower
|
||||||
|
}
|
||||||
|
return lower[selectPos+7 : fromPos]
|
||||||
|
}
|
||||||
|
|
||||||
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
|
// replaceMetaVariables replaces meta variables like [rid_user], [user], etc. in the SQL query
|
||||||
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
func (h *Handler) replaceMetaVariables(sqlquery string, r *http.Request, userCtx *security.UserContext, metainfo map[string]interface{}, variables map[string]interface{}) string {
|
||||||
if strings.Contains(sqlquery, "[p_meta_default]") {
|
if strings.Contains(sqlquery, "[p_meta_default]") {
|
||||||
@@ -969,6 +1006,37 @@ func IsNumeric(s string) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isInsideDollarQuote reports whether the first occurrence of placeholder in sqlquery
|
||||||
|
// is immediately surrounded by dollar-sign characters (i.e. inside a $...$-quoted string).
|
||||||
|
// Dollar-quoted strings pass content through literally — no backslash processing — so
|
||||||
|
// values placed there must NOT have their backslashes escaped.
|
||||||
|
func isInsideDollarQuote(sqlquery, placeholder string) bool {
|
||||||
|
idx := strings.Index(sqlquery, placeholder)
|
||||||
|
if idx < 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
endIdx := idx + len(placeholder)
|
||||||
|
charBefore := byte(0)
|
||||||
|
charAfter := byte(0)
|
||||||
|
if idx > 0 {
|
||||||
|
charBefore = sqlquery[idx-1]
|
||||||
|
}
|
||||||
|
if endIdx < len(sqlquery) {
|
||||||
|
charAfter = sqlquery[endIdx]
|
||||||
|
}
|
||||||
|
return charBefore == '$' || charAfter == '$'
|
||||||
|
}
|
||||||
|
|
||||||
|
// safeSubstituteVar returns value sanitised for the quoting context that surrounds
|
||||||
|
// placeholder in sqlquery: raw (no backslash escaping) for dollar-quoted contexts,
|
||||||
|
// ValidSQL("colvalue") escaping for everything else.
|
||||||
|
func safeSubstituteVar(sqlquery, placeholder, value string) string {
|
||||||
|
if isInsideDollarQuote(sqlquery, placeholder) {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return ValidSQL(value, "colvalue")
|
||||||
|
}
|
||||||
|
|
||||||
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
// getReplacementForBlankParam determines the replacement value for an unused parameter
|
||||||
// based on whether it appears within quotes in the SQL query.
|
// based on whether it appears within quotes in the SQL query.
|
||||||
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
// It checks for PostgreSQL quotes: single quotes (”) and dollar quotes ($...$)
|
||||||
@@ -991,8 +1059,8 @@ func getReplacementForBlankParam(sqlquery, param string) string {
|
|||||||
charAfter = sqlquery[endIdx]
|
charAfter = sqlquery[endIdx]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if parameter is surrounded by quotes (single quote or dollar sign for PostgreSQL dollar-quoted strings)
|
// Check if parameter is surrounded by quotes (single quote, dollar sign for PostgreSQL dollar-quoted strings, or double quote for JSON string values)
|
||||||
if (charBefore == '\'' || charBefore == '$') && (charAfter == '\'' || charAfter == '$') {
|
if (charBefore == '\'' || charBefore == '$' || charBefore == '"') && (charAfter == '\'' || charAfter == '$' || charAfter == '"') {
|
||||||
// Parameter is in quotes, return empty string
|
// Parameter is in quotes, return empty string
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -1035,6 +1103,10 @@ func sendError(w http.ResponseWriter, status int, code, message string, err erro
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errObj.Detail = err.Error()
|
errObj.Detail = err.Error()
|
||||||
|
var sqlErr *common.SQLError
|
||||||
|
if errors.As(err, &sqlErr) {
|
||||||
|
errObj.SQL = sqlErr.SQL
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data, _ := json.Marshal(map[string]interface{}{
|
data, _ := json.Marshal(map[string]interface{}{
|
||||||
|
|||||||
@@ -821,7 +821,7 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
name: "Replace [user]",
|
name: "Replace [user]",
|
||||||
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
|
sqlQuery: "SELECT * FROM audit WHERE username = [user]",
|
||||||
expectedCheck: func(result string) bool {
|
expectedCheck: func(result string) bool {
|
||||||
return strings.Contains(result, "'testuser'")
|
return strings.Contains(result, "$USR$testuser$USR$")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -851,6 +851,285 @@ func TestReplaceMetaVariables(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSqlStripStringLiterals tests that single-quoted string literals are removed
|
||||||
|
func TestSqlStripStringLiterals(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "No string literals",
|
||||||
|
input: "SELECT rid, rid_parent FROM users",
|
||||||
|
expected: "SELECT rid, rid_parent FROM users",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simple string literal",
|
||||||
|
input: "SELECT * FROM users WHERE mode = 'admin'",
|
||||||
|
expected: "SELECT * FROM users WHERE mode = ''",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "JSON argument containing column names",
|
||||||
|
input: `SELECT rid, rid_parent FROM crm_get_menu(1,'mode', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')`,
|
||||||
|
expected: `SELECT rid, rid_parent FROM crm_get_menu(1,'', '')`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Escaped single quotes inside literal",
|
||||||
|
input: "SELECT * FROM t WHERE name = 'O''Brien'",
|
||||||
|
expected: "SELECT * FROM t WHERE name = ''",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := sqlStripStringLiterals(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("sqlStripStringLiterals() =\n %q\nwant\n %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllowFilterDoesNotMatchInsideJsonArgument verifies that AllowFilter will add WHERE
|
||||||
|
// clauses for real output columns (rid, rid_parent) but not for names that only appear
|
||||||
|
// inside a JSON string argument (cf_startdate, cf_rid_branch).
|
||||||
|
func TestAllowFilterDoesNotMatchInsideJsonArgument(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
sqlQuery := `select rid, rid_parent, description
|
||||||
|
from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]", "CF:STARTDATE": "[cf_startdate]", "CF:RID_BRANCH": "[cf_rid_branch]"}')`
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
checkResult func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "rid_parent=0 is a real column — filter applied",
|
||||||
|
queryParams: map[string]string{"rid_parent": "0"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Error("Expected WHERE clause to be added for rid_parent")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "rid_parent = 0 OR") && !strings.Contains(result, "rid_parent IS NULL") {
|
||||||
|
t.Errorf("Expected null-safe filter for rid_parent=0, got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cf_startdate only appears in JSON string — no filter applied",
|
||||||
|
queryParams: map[string]string{"cf_startdate": "2024-01-01"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Errorf("Expected no WHERE clause for cf_startdate (only in JSON arg), got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cf_rid_branch only appears in JSON string — no filter applied",
|
||||||
|
queryParams: map[string]string{"cf_rid_branch": "5"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Errorf("Expected no WHERE clause for cf_rid_branch (only in JSON arg), got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "description is a real column — filter applied",
|
||||||
|
queryParams: map[string]string{"description": "test"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Error("Expected WHERE clause for description")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||||
|
variables := make(map[string]interface{})
|
||||||
|
propQry := make(map[string]string)
|
||||||
|
|
||||||
|
result := handler.mergeQueryParams(req, sqlQuery, variables, true, propQry)
|
||||||
|
tt.checkResult(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllowFilterDoesNotMatchFunctionParams verifies that query params that appear only
|
||||||
|
// as function call arguments in the FROM clause (e.g. [p_rid_doctype]) are not treated
|
||||||
|
// as column filters, since they are not in the SELECT list.
|
||||||
|
func TestAllowFilterDoesNotMatchFunctionParams(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
sqlQuery := `select rid, rid_parent, description, row_cnt, filterstring, tableprefix, rid_table, tooltip, additionalfilter, haschildren
|
||||||
|
from crm_get_doc_menu($JQ$[p_tableprefix]$JQ$,[p_rid_parent],[p_rid_doctype],[p_removedup],[p_showall]) r`
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
checkResult func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "p_rid_doctype is a function param, not a column — no filter applied",
|
||||||
|
queryParams: map[string]string{"p_rid_doctype": "0"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Errorf("Expected no WHERE clause for p_rid_doctype (function arg, not SELECT column), got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "p_showall is a function param, not a column — no filter applied",
|
||||||
|
queryParams: map[string]string{"p_showall": "1"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Errorf("Expected no WHERE clause for p_showall (function arg, not SELECT column), got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rid is a SELECT column — filter applied",
|
||||||
|
queryParams: map[string]string{"rid": "42"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(strings.ToLower(result), "where") {
|
||||||
|
t.Error("Expected WHERE clause for rid (real SELECT column)")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||||
|
variables := make(map[string]interface{})
|
||||||
|
propQry := make(map[string]string)
|
||||||
|
result := handler.mergeQueryParams(req, sqlQuery, variables, true, propQry)
|
||||||
|
tt.checkResult(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGetReplacementForBlankParamDoubleQuote verifies that placeholders surrounded by
|
||||||
|
// double quotes (as in JSON string values) are blanked to "" not NULL.
|
||||||
|
func TestGetReplacementForBlankParamDoubleQuote(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
sqlQuery string
|
||||||
|
param string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Parameter in double quotes (JSON value)",
|
||||||
|
sqlQuery: `SELECT * FROM f(1, '{"key":"[myparam]"}')`,
|
||||||
|
param: "[myparam]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter not in any quotes",
|
||||||
|
sqlQuery: `SELECT * FROM f([myparam])`,
|
||||||
|
param: "[myparam]",
|
||||||
|
expected: "NULL",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Parameter in single quotes",
|
||||||
|
sqlQuery: `SELECT * FROM f('[myparam]')`,
|
||||||
|
param: "[myparam]",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := getReplacementForBlankParam(tt.sqlQuery, tt.param)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("getReplacementForBlankParam() = %q, want %q\nquery: %s", result, tt.expected, tt.sqlQuery)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestVariableReplacementFromQueryParams verifies that query params matching [placeholder]
|
||||||
|
// tokens are substituted even when they don't have the p- prefix.
|
||||||
|
func TestVariableReplacementFromQueryParams(t *testing.T) {
|
||||||
|
handler := NewHandler(&MockDatabase{})
|
||||||
|
|
||||||
|
sqlQuery := `select rid, rid_parent from crm_get_menu([rid_user],'[p_mode]', 0, '', '{"rid_parent":"[rid_parent]","CF:STARTDATE":"[cf_startdate]"}')`
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
queryParams map[string]string
|
||||||
|
checkResult func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "rid_parent replaced from query param",
|
||||||
|
queryParams: map[string]string{"rid_parent": "42"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(result, "[rid_parent]") {
|
||||||
|
t.Errorf("Expected [rid_parent] to be replaced, still present in:\n%s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "42") {
|
||||||
|
t.Errorf("Expected value 42 in query, got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cf_startdate replaced from query param",
|
||||||
|
queryParams: map[string]string{"cf_startdate": "2024-01-01"},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
if strings.Contains(result, "[cf_startdate]") {
|
||||||
|
t.Errorf("Expected [cf_startdate] to be replaced, still present in:\n%s", result)
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "2024-01-01") {
|
||||||
|
t.Errorf("Expected date value in query, got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing param blanked to empty string inside JSON (double-quoted)",
|
||||||
|
queryParams: map[string]string{},
|
||||||
|
checkResult: func(t *testing.T, result string) {
|
||||||
|
// [cf_startdate] is surrounded by " in the JSON — should blank to ""
|
||||||
|
if strings.Contains(result, "[cf_startdate]") {
|
||||||
|
t.Errorf("Expected [cf_startdate] to be blanked, still present in:\n%s", result)
|
||||||
|
}
|
||||||
|
if strings.Contains(result, "NULL") && strings.Contains(result, "cf_startdate") {
|
||||||
|
t.Errorf("Expected empty string (not NULL) for double-quoted placeholder, got:\n%s", result)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
inputvars := make([]string, 0)
|
||||||
|
q := handler.extractInputVariables(sqlQuery, &inputvars)
|
||||||
|
|
||||||
|
req := createTestRequest("GET", "/test", tt.queryParams, nil, nil)
|
||||||
|
variables := make(map[string]interface{})
|
||||||
|
propQry := make(map[string]string)
|
||||||
|
|
||||||
|
q = handler.mergeQueryParams(req, q, variables, false, propQry)
|
||||||
|
|
||||||
|
// Simulate the variable replacement + blank-param loop (mirrors function_api.go)
|
||||||
|
for _, kw := range inputvars {
|
||||||
|
varName := kw[1 : len(kw)-1]
|
||||||
|
if val, ok := variables[varName]; ok {
|
||||||
|
if strVal := strings.TrimSpace(val.(string)); strVal != "" {
|
||||||
|
q = strings.ReplaceAll(q, kw, ValidSQL(strVal, "colvalue"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
replacement := getReplacementForBlankParam(q, kw)
|
||||||
|
q = strings.ReplaceAll(q, kw, replacement)
|
||||||
|
}
|
||||||
|
|
||||||
|
tt.checkResult(t, q)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
|
// TestGetReplacementForBlankParam tests the blank parameter replacement logic
|
||||||
func TestGetReplacementForBlankParam(t *testing.T) {
|
func TestGetReplacementForBlankParam(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|||||||
@@ -259,7 +259,7 @@ func (h *Handler) ApplyFilters(sqlQuery string, params *RequestParameters) strin
|
|||||||
for colName, value := range params.SearchFilters {
|
for colName, value := range params.SearchFilters {
|
||||||
sval := strings.ReplaceAll(value, "'", "")
|
sval := strings.ReplaceAll(value, "'", "")
|
||||||
if sval != "" {
|
if sval != "" {
|
||||||
condition := fmt.Sprintf("%s ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
condition := fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", ValidSQL(colName, "colname"), ValidSQL(sval, "colvalue"))
|
||||||
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
sqlQuery = sqlQryWhere(sqlQuery, condition)
|
||||||
logger.Debug("Applied search filter: %s", condition)
|
logger.Debug("Applied search filter: %s", condition)
|
||||||
}
|
}
|
||||||
@@ -307,11 +307,11 @@ func (h *Handler) buildFilterCondition(colName string, op FilterOperator) string
|
|||||||
|
|
||||||
switch operator {
|
switch operator {
|
||||||
case "contains", "contain", "like":
|
case "contains", "contain", "like":
|
||||||
return fmt.Sprintf("%s ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||||
case "beginswith", "startswith":
|
case "beginswith", "startswith":
|
||||||
return fmt.Sprintf("%s ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%s%%'", safCol, ValidSQL(value, "colvalue"))
|
||||||
case "endswith":
|
case "endswith":
|
||||||
return fmt.Sprintf("%s ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE '%%%s'", safCol, ValidSQL(value, "colvalue"))
|
||||||
case "equals", "eq", "=":
|
case "equals", "eq", "=":
|
||||||
if IsNumeric(value) {
|
if IsNumeric(value) {
|
||||||
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
|
return fmt.Sprintf("%s = %s", safCol, ValidSQL(value, "colvalue"))
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
|||||||
Value: "test",
|
Value: "test",
|
||||||
Logic: "AND",
|
Logic: "AND",
|
||||||
},
|
},
|
||||||
expected: "description ILIKE '%test%'",
|
expected: "CAST(description AS TEXT) ILIKE '%test%'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Starts with operator",
|
name: "Starts with operator",
|
||||||
@@ -284,7 +284,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
|||||||
Value: "john",
|
Value: "john",
|
||||||
Logic: "AND",
|
Logic: "AND",
|
||||||
},
|
},
|
||||||
expected: "name ILIKE 'john%'",
|
expected: "CAST(name AS TEXT) ILIKE 'john%'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Ends with operator",
|
name: "Ends with operator",
|
||||||
@@ -294,7 +294,7 @@ func TestBuildFilterCondition(t *testing.T) {
|
|||||||
Value: "@example.com",
|
Value: "@example.com",
|
||||||
Logic: "AND",
|
Logic: "AND",
|
||||||
},
|
},
|
||||||
expected: "email ILIKE '%@example.com'",
|
expected: "CAST(email AS TEXT) ILIKE '%@example.com'",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Between operator",
|
name: "Between operator",
|
||||||
|
|||||||
@@ -2,14 +2,38 @@ package funcspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
// RegisterSecurityHooks registers security hooks for funcspec handlers
|
||||||
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
// Note: funcspec operates on SQL queries directly, so row-level security is not directly applicable
|
||||||
// We provide audit logging for data access tracking
|
// We provide auth enforcement and audit logging for data access tracking
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeQueryList - Auth check before list query execution
|
||||||
|
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||||
|
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = "authentication required"
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 0: BeforeQuery - Auth check before single query execution
|
||||||
|
handler.Hooks().Register(BeforeQuery, func(hookCtx *HookContext) error {
|
||||||
|
if hookCtx.UserContext == nil || hookCtx.UserContext.UserID == 0 {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = "authentication required"
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
// Hook 1: BeforeQueryList - Audit logging before query list execution
|
||||||
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeQueryList, func(hookCtx *HookContext) error {
|
||||||
secCtx := newFuncSpecSecurityContext(hookCtx)
|
secCtx := newFuncSpecSecurityContext(hookCtx)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package metrics
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -19,7 +20,7 @@ type Provider interface {
|
|||||||
DecRequestsInFlight()
|
DecRequestsInFlight()
|
||||||
|
|
||||||
// RecordDBQuery records metrics for a database query
|
// RecordDBQuery records metrics for a database query
|
||||||
RecordDBQuery(operation, table string, duration time.Duration, err error)
|
RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error)
|
||||||
|
|
||||||
// RecordCacheHit records a cache hit
|
// RecordCacheHit records a cache hit
|
||||||
RecordCacheHit(provider string)
|
RecordCacheHit(provider string)
|
||||||
@@ -46,21 +47,28 @@ type Provider interface {
|
|||||||
Handler() http.Handler
|
Handler() http.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// globalProvider is the global metrics provider
|
// globalProvider is the global metrics provider, protected by globalProviderMu.
|
||||||
var globalProvider Provider
|
var (
|
||||||
|
globalProviderMu sync.RWMutex
|
||||||
|
globalProvider Provider
|
||||||
|
)
|
||||||
|
|
||||||
// SetProvider sets the global metrics provider
|
// SetProvider sets the global metrics provider.
|
||||||
func SetProvider(p Provider) {
|
func SetProvider(p Provider) {
|
||||||
|
globalProviderMu.Lock()
|
||||||
globalProvider = p
|
globalProvider = p
|
||||||
|
globalProviderMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProvider returns the current metrics provider
|
// GetProvider returns the current metrics provider.
|
||||||
func GetProvider() Provider {
|
func GetProvider() Provider {
|
||||||
if globalProvider == nil {
|
globalProviderMu.RLock()
|
||||||
// Return no-op provider if none is set
|
p := globalProvider
|
||||||
|
globalProviderMu.RUnlock()
|
||||||
|
if p == nil {
|
||||||
return &NoOpProvider{}
|
return &NoOpProvider{}
|
||||||
}
|
}
|
||||||
return globalProvider
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
// NoOpProvider is a no-op implementation of Provider
|
// NoOpProvider is a no-op implementation of Provider
|
||||||
@@ -69,7 +77,7 @@ type NoOpProvider struct{}
|
|||||||
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
|
func (n *NoOpProvider) RecordHTTPRequest(method, path, status string, duration time.Duration) {}
|
||||||
func (n *NoOpProvider) IncRequestsInFlight() {}
|
func (n *NoOpProvider) IncRequestsInFlight() {}
|
||||||
func (n *NoOpProvider) DecRequestsInFlight() {}
|
func (n *NoOpProvider) DecRequestsInFlight() {}
|
||||||
func (n *NoOpProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
func (n *NoOpProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||||
}
|
}
|
||||||
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
func (n *NoOpProvider) RecordCacheHit(provider string) {}
|
||||||
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
func (n *NoOpProvider) RecordCacheMiss(provider string) {}
|
||||||
|
|||||||
@@ -83,14 +83,14 @@ func NewPrometheusProvider(cfg *Config) *PrometheusProvider {
|
|||||||
Help: "Database query duration in seconds",
|
Help: "Database query duration in seconds",
|
||||||
Buckets: cfg.DBQueryBuckets,
|
Buckets: cfg.DBQueryBuckets,
|
||||||
},
|
},
|
||||||
[]string{"operation", "table"},
|
[]string{"operation", "schema", "entity", "table"},
|
||||||
),
|
),
|
||||||
dbQueryTotal: promauto.NewCounterVec(
|
dbQueryTotal: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
Name: metricName("db_queries_total"),
|
Name: metricName("db_queries_total"),
|
||||||
Help: "Total number of database queries",
|
Help: "Total number of database queries",
|
||||||
},
|
},
|
||||||
[]string{"operation", "table", "status"},
|
[]string{"operation", "schema", "entity", "table", "status"},
|
||||||
),
|
),
|
||||||
cacheHits: promauto.NewCounterVec(
|
cacheHits: promauto.NewCounterVec(
|
||||||
prometheus.CounterOpts{
|
prometheus.CounterOpts{
|
||||||
@@ -204,13 +204,13 @@ func (p *PrometheusProvider) DecRequestsInFlight() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RecordDBQuery implements Provider interface
|
// RecordDBQuery implements Provider interface
|
||||||
func (p *PrometheusProvider) RecordDBQuery(operation, table string, duration time.Duration, err error) {
|
func (p *PrometheusProvider) RecordDBQuery(operation, schema, entity, table string, duration time.Duration, err error) {
|
||||||
status := "success"
|
status := "success"
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status = "error"
|
status = "error"
|
||||||
}
|
}
|
||||||
p.dbQueryDuration.WithLabelValues(operation, table).Observe(duration.Seconds())
|
p.dbQueryDuration.WithLabelValues(operation, schema, entity, table).Observe(duration.Seconds())
|
||||||
p.dbQueryTotal.WithLabelValues(operation, table, status).Inc()
|
p.dbQueryTotal.WithLabelValues(operation, schema, entity, table, status).Inc()
|
||||||
}
|
}
|
||||||
|
|
||||||
// RecordCacheHit implements Provider interface
|
// RecordCacheHit implements Provider interface
|
||||||
|
|||||||
@@ -8,6 +8,10 @@ import (
|
|||||||
|
|
||||||
// ModelRules defines the permissions and security settings for a model
|
// ModelRules defines the permissions and security settings for a model
|
||||||
type ModelRules struct {
|
type ModelRules struct {
|
||||||
|
CanPublicRead bool // Whether the model can be read (GET operations)
|
||||||
|
CanPublicUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||||
|
CanPublicCreate bool // Whether the model can be created (POST operations)
|
||||||
|
CanPublicDelete bool // Whether the model can be deleted (DELETE operations)
|
||||||
CanRead bool // Whether the model can be read (GET operations)
|
CanRead bool // Whether the model can be read (GET operations)
|
||||||
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
CanUpdate bool // Whether the model can be updated (PUT/PATCH operations)
|
||||||
CanCreate bool // Whether the model can be created (POST operations)
|
CanCreate bool // Whether the model can be created (POST operations)
|
||||||
@@ -22,6 +26,10 @@ func DefaultModelRules() ModelRules {
|
|||||||
CanUpdate: true,
|
CanUpdate: true,
|
||||||
CanCreate: true,
|
CanCreate: true,
|
||||||
CanDelete: true,
|
CanDelete: true,
|
||||||
|
CanPublicRead: false,
|
||||||
|
CanPublicUpdate: false,
|
||||||
|
CanPublicCreate: false,
|
||||||
|
CanPublicDelete: false,
|
||||||
SecurityDisabled: false,
|
SecurityDisabled: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+18
-3
@@ -9,7 +9,7 @@ MQTTSpec is an MQTT-based database query framework that enables real-time databa
|
|||||||
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
- **Full CRUD Operations**: Create, Read, Update, Delete with hooks
|
||||||
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
- **Real-time Subscriptions**: Subscribe to entity changes with filtering
|
||||||
- **Database Agnostic**: GORM and Bun ORM support
|
- **Database Agnostic**: GORM and Bun ORM support
|
||||||
- **Lifecycle Hooks**: 12 hooks for authentication, authorization, validation, and auditing
|
- **Lifecycle Hooks**: 13 hooks for authentication, authorization, validation, and auditing
|
||||||
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
- **Multi-tenancy Support**: Built-in tenant isolation via hooks
|
||||||
- **Thread-safe**: Proper concurrency handling throughout
|
- **Thread-safe**: Proper concurrency handling throughout
|
||||||
|
|
||||||
@@ -326,10 +326,11 @@ When any client creates/updates/deletes a user matching the subscription filters
|
|||||||
|
|
||||||
## Lifecycle Hooks
|
## Lifecycle Hooks
|
||||||
|
|
||||||
MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
MQTTSpec provides 13 lifecycle hooks for implementing cross-cutting concerns:
|
||||||
|
|
||||||
### Hook Types
|
### Hook Types
|
||||||
|
|
||||||
|
- `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||||
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
- `BeforeConnect` / `AfterConnect` - Connection lifecycle
|
||||||
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
- `BeforeDisconnect` / `AfterDisconnect` - Disconnection lifecycle
|
||||||
- `BeforeRead` / `AfterRead` - Read operations
|
- `BeforeRead` / `AfterRead` - Read operations
|
||||||
@@ -339,6 +340,20 @@ MQTTSpec provides 12 lifecycle hooks for implementing cross-cutting concerns:
|
|||||||
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
- `BeforeSubscribe` / `AfterSubscribe` - Subscription creation
|
||||||
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
- `BeforeUnsubscribe` / `AfterUnsubscribe` - Subscription removal
|
||||||
|
|
||||||
|
### Security Hooks (Recommended)
|
||||||
|
|
||||||
|
Use `RegisterSecurityHooks` for integrated auth with model-rule support:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
provider := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList := security.NewSecurityList(provider)
|
||||||
|
mqttspec.RegisterSecurityHooks(handler, securityList)
|
||||||
|
// Registers BeforeHandle (model auth), BeforeRead (load rules),
|
||||||
|
// AfterRead (column security + audit), BeforeUpdate, BeforeDelete
|
||||||
|
```
|
||||||
|
|
||||||
### Authentication Example (JWT)
|
### Authentication Example (JWT)
|
||||||
|
|
||||||
```go
|
```go
|
||||||
@@ -657,7 +672,7 @@ handler, err := mqttspec.NewHandlerWithGORM(db,
|
|||||||
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
| **Network Efficiency** | Better for unreliable networks | Better for low-latency |
|
||||||
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
| **Best For** | IoT, mobile apps, distributed systems | Web applications, real-time dashboards |
|
||||||
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
| **Message Protocol** | Same JSON structure | Same JSON structure |
|
||||||
| **Hooks** | Same 12 hooks | Same 12 hooks |
|
| **Hooks** | Same 13 hooks | Same 13 hooks |
|
||||||
| **CRUD Operations** | Identical | Identical |
|
| **CRUD Operations** | Identical | Identical |
|
||||||
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
| **Subscriptions** | Identical (via MQTT topics) | Identical (via app-level) |
|
||||||
|
|
||||||
|
|||||||
@@ -284,6 +284,15 @@ func (h *Handler) handleRequest(client *Client, msg *Message) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
hookCtx.Operation = string(msg.Operation)
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
if hookCtx.Abort {
|
||||||
|
h.sendError(client.ID, msg.ID, "unauthorized", hookCtx.AbortMessage)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Route to operation handler
|
// Route to operation handler
|
||||||
switch msg.Operation {
|
switch msg.Operation {
|
||||||
case OperationRead:
|
case OperationRead:
|
||||||
@@ -693,8 +702,13 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
|||||||
if hookCtx.Options != nil {
|
if hookCtx.Options != nil {
|
||||||
// Apply filters
|
// Apply filters
|
||||||
for _, filter := range hookCtx.Options.Filters {
|
for _, filter := range hookCtx.Options.Filters {
|
||||||
|
op := strings.ToLower(filter.Operator)
|
||||||
|
if op == "like" || op == "ilike" {
|
||||||
|
query = query.Where(fmt.Sprintf("CAST(%s AS TEXT) %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||||
|
} else {
|
||||||
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
query = query.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Apply sorting
|
// Apply sorting
|
||||||
for _, sort := range hookCtx.Options.Sort {
|
for _, sort := range hookCtx.Options.Sort {
|
||||||
@@ -734,9 +748,14 @@ func (h *Handler) readMultiple(hookCtx *HookContext) (data interface{}, metadata
|
|||||||
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
countQuery := h.db.NewSelect().Model(hookCtx.ModelPtr).Table(hookCtx.TableName)
|
||||||
if hookCtx.Options != nil {
|
if hookCtx.Options != nil {
|
||||||
for _, filter := range hookCtx.Options.Filters {
|
for _, filter := range hookCtx.Options.Filters {
|
||||||
|
op := strings.ToLower(filter.Operator)
|
||||||
|
if op == "like" || op == "ilike" {
|
||||||
|
countQuery = countQuery.Where(fmt.Sprintf("CAST(%s AS TEXT) %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||||
|
} else {
|
||||||
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
countQuery = countQuery.Where(fmt.Sprintf("%s %s ?", filter.Column, h.getOperatorSQL(filter.Operator)), filter.Value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
count, _ := countQuery.Count(hookCtx.Context)
|
count, _ := countQuery.Count(hookCtx.Context)
|
||||||
metadata["total"] = count
|
metadata["total"] = count
|
||||||
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
|
metadata["count"] = reflection.Len(hookCtx.ModelPtr)
|
||||||
|
|||||||
@@ -20,8 +20,11 @@ type (
|
|||||||
HookRegistry = websocketspec.HookRegistry
|
HookRegistry = websocketspec.HookRegistry
|
||||||
)
|
)
|
||||||
|
|
||||||
// Hook type constants - all 12 lifecycle hooks
|
// Hook type constants - all lifecycle hooks
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch
|
||||||
|
BeforeHandle = websocketspec.BeforeHandle
|
||||||
|
|
||||||
// CRUD operation hooks
|
// CRUD operation hooks
|
||||||
BeforeRead = websocketspec.BeforeRead
|
BeforeRead = websocketspec.BeforeRead
|
||||||
AfterRead = websocketspec.AfterRead
|
AfterRead = websocketspec.AfterRead
|
||||||
|
|||||||
@@ -0,0 +1,108 @@
|
|||||||
|
package mqttspec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks registers all security-related hooks with the MQTT handler
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 1: BeforeRead - Load security rules
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LoadSecurityRules(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 2: AfterRead - Apply column-level security (masking)
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.ApplyColumnSecurity(secCtx, securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 3 (Optional): Audit logging
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.LogDataAccess(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 4: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for mqttspec handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// securityContext adapts mqttspec.HookContext to security.SecurityContext interface
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQuery retrieves a stored query from hook metadata
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.ctx.Metadata["query"]
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetQuery stores the query in hook metadata
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if s.ctx.Metadata == nil {
|
||||||
|
s.ctx.Metadata = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
s.ctx.Metadata["query"] = query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -51,6 +51,31 @@ func ExtractTableNameOnly(fullName string) string {
|
|||||||
return fullName[startIndex:]
|
return fullName[startIndex:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEmptyValue reports whether v is nil, an empty string, or a zero number.
|
||||||
|
func IsEmptyValue(v any) bool {
|
||||||
|
if v == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
rv := reflect.ValueOf(v)
|
||||||
|
if rv.Kind() == reflect.Ptr {
|
||||||
|
if rv.IsNil() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
switch rv.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
return rv.String() == ""
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
return rv.Int() == 0
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
|
||||||
|
return rv.Uint() == 0
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return rv.Float() == 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// GetPointerElement returns the element type if the provided reflect.Type is a pointer.
|
// GetPointerElement returns the element type if the provided reflect.Type is a pointer.
|
||||||
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
// If the type is a slice of pointers, it returns the element type of the pointer within the slice.
|
||||||
// If neither condition is met, it returns the original type.
|
// If neither condition is met, it returns the original type.
|
||||||
@@ -76,9 +101,14 @@ func GetJSONNameForField(modelType reflect.Type, fieldName string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle pointer types
|
// Unwrap pointer and slice indirections to reach the struct type
|
||||||
if modelType.Kind() == reflect.Ptr {
|
for {
|
||||||
|
switch modelType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() != reflect.Struct {
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
|||||||
+230
-12
@@ -196,6 +196,92 @@ func collectColumnsFromType(typ reflect.Type, columns *[]string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetColumnName extracts the database column name from a struct field.
|
||||||
|
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name.
|
||||||
|
// This is the exported version for use by other packages.
|
||||||
|
func GetColumnName(field reflect.StructField) string {
|
||||||
|
return getColumnNameFromField(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildJSONToDBColumnMap returns a map from JSON key names to database column names
|
||||||
|
// for the given model type. Only writable, non-relation fields are included.
|
||||||
|
// This is used to translate incoming request data (keyed by JSON names) into
|
||||||
|
// properly named database columns before insert/update operations.
|
||||||
|
func BuildJSONToDBColumnMap(modelType reflect.Type) map[string]string {
|
||||||
|
result := make(map[string]string)
|
||||||
|
buildJSONToDBMap(modelType, result, false)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildJSONToDBMap(modelType reflect.Type, result map[string]string, scanOnly bool) {
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
if !field.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
gormTag := field.Tag.Get("gorm")
|
||||||
|
|
||||||
|
// Handle embedded structs
|
||||||
|
if field.Anonymous {
|
||||||
|
ft := field.Type
|
||||||
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
isScanOnly := scanOnly
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
isScanOnly = true
|
||||||
|
}
|
||||||
|
if ft.Kind() == reflect.Struct {
|
||||||
|
buildJSONToDBMap(ft, result, isScanOnly)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanOnly {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip explicitly excluded fields
|
||||||
|
if bunTag == "-" || gormTag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip scan-only fields
|
||||||
|
if bunTag != "" && isBunFieldScanOnly(bunTag) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip bun relation fields
|
||||||
|
if bunTag != "" && (strings.Contains(bunTag, "rel:") || strings.Contains(bunTag, "join:") || strings.Contains(bunTag, "m2m:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip gorm relation fields
|
||||||
|
if gormTag != "" && (strings.Contains(gormTag, "foreignKey:") || strings.Contains(gormTag, "references:") || strings.Contains(gormTag, "many2many:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get JSON key (how the field appears in incoming request data)
|
||||||
|
jsonKey := ""
|
||||||
|
if jsonTag := field.Tag.Get("json"); jsonTag != "" && jsonTag != "-" {
|
||||||
|
parts := strings.Split(jsonTag, ",")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
jsonKey = parts[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if jsonKey == "" {
|
||||||
|
jsonKey = strings.ToLower(field.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the actual DB column name (bun > gorm > json > field name)
|
||||||
|
dbColName := getColumnNameFromField(field)
|
||||||
|
|
||||||
|
result[jsonKey] = dbColName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getColumnNameFromField extracts the column name from a struct field
|
// getColumnNameFromField extracts the column name from a struct field
|
||||||
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
// Priority: bun tag -> gorm tag -> json tag -> lowercase field name
|
||||||
func getColumnNameFromField(field reflect.StructField) string {
|
func getColumnNameFromField(field reflect.StructField) string {
|
||||||
@@ -455,9 +541,14 @@ func collectSQLColumnsFromType(typ reflect.Type, columns *[]string, scanOnlyEmbe
|
|||||||
func IsColumnWritable(model any, columnName string) bool {
|
func IsColumnWritable(model any, columnName string) bool {
|
||||||
modelType := reflect.TypeOf(model)
|
modelType := reflect.TypeOf(model)
|
||||||
|
|
||||||
// Unwrap pointers to get to the base struct type
|
// Unwrap pointers and slices to get to the base struct type
|
||||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
for modelType != nil {
|
||||||
|
switch modelType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that we have a struct type
|
// Validate that we have a struct type
|
||||||
@@ -792,8 +883,14 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
return RelationUnknown
|
return RelationUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
// Unwrap pointer → slice → pointer chains to reach the underlying struct
|
||||||
|
for {
|
||||||
|
switch modelType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
@@ -876,6 +973,108 @@ func GetRelationType(model interface{}, fieldName string) RelationType {
|
|||||||
return RelationUnknown
|
return RelationUnknown
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetForeignKeyColumn returns the DB column names of the foreign key(s) that
|
||||||
|
// relate parentKey to modelType. Composite keys (e.g. bun "join:a=b,join:c=d"
|
||||||
|
// or GORM "foreignKey:ColA,ColB") yield multiple entries. Returns nil when no
|
||||||
|
// tag is found (caller should fall back to convention).
|
||||||
|
//
|
||||||
|
// Two lookup strategies are tried in order:
|
||||||
|
//
|
||||||
|
// 1. Relation-field match: find a field whose name/json equals parentKey, then
|
||||||
|
// read its bun join: or GORM foreignKey: tag and return the local columns.
|
||||||
|
// e.g. parentKey="department", field `Department bun:"join:dept_id=id"` → ["dept_id"]
|
||||||
|
//
|
||||||
|
// 2. Join left-side scan: scan every bun join tag in the struct for pairs whose
|
||||||
|
// left side equals parentKey and return the right-side (child FK) columns.
|
||||||
|
// e.g. parentKey="rid_mastertaskitem", field `Children bun:"join:rid_mastertaskitem=rid_parentmastertaskitem"` → ["rid_parentmastertaskitem"]
|
||||||
|
// Strategy 1 is skipped if the matched field is a declared relation (rel:) or
|
||||||
|
// has a GORM tag but carries no explicit FK — callers should use convention.
|
||||||
|
func GetForeignKeyColumn(modelType reflect.Type, parentKey string) []string {
|
||||||
|
for modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
if modelType.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy 1: match parentKey against a field's name/json tag.
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
|
||||||
|
name := field.Name
|
||||||
|
jsonName := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||||
|
if !strings.EqualFold(name, parentKey) && !strings.EqualFold(jsonName, parentKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
|
||||||
|
// Bun: join:local_col=foreign_col (one join: part per pair)
|
||||||
|
var bunCols []string
|
||||||
|
for _, part := range strings.Split(bunTag, ",") {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "join:") {
|
||||||
|
pair := strings.TrimPrefix(part, "join:")
|
||||||
|
if idx := strings.Index(pair, "="); idx > 0 {
|
||||||
|
bunCols = append(bunCols, pair[:idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(bunCols) > 0 {
|
||||||
|
return bunCols
|
||||||
|
}
|
||||||
|
|
||||||
|
// GORM: foreignKey:FieldA,FieldB
|
||||||
|
for _, part := range strings.Split(field.Tag.Get("gorm"), ";") {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "foreignKey:") {
|
||||||
|
var cols []string
|
||||||
|
for _, fkFieldName := range strings.Split(strings.TrimPrefix(part, "foreignKey:"), ",") {
|
||||||
|
fkFieldName = strings.TrimSpace(fkFieldName)
|
||||||
|
if fkField, ok := modelType.FieldByName(fkFieldName); ok {
|
||||||
|
cols = append(cols, getColumnNameFromField(fkField))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(cols) > 0 {
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The field matched by name/json but has no explicit FK tag. If it is a
|
||||||
|
// declared relation field (rel:) or carries a GORM tag, the caller should
|
||||||
|
// use naming convention — don't fall through to strategy 2. Otherwise the
|
||||||
|
// matched field is a plain scalar column; proceed to the join left-side scan.
|
||||||
|
if strings.Contains(bunTag, "rel:") || field.Tag.Get("gorm") != "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy 2: scan every field's bun join tag for pairs whose left side (the
|
||||||
|
// parent's column) matches parentKey; the right side is the child FK column.
|
||||||
|
// This handles cases where parentKey is a raw column name rather than a
|
||||||
|
// relation field name (e.g. self-referential or has-many relationships).
|
||||||
|
seen := map[string]bool{}
|
||||||
|
var cols []string
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
for _, part := range strings.Split(modelType.Field(i).Tag.Get("bun"), ",") {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if strings.HasPrefix(part, "join:") {
|
||||||
|
pair := strings.TrimPrefix(part, "join:")
|
||||||
|
if idx := strings.Index(pair, "="); idx > 0 {
|
||||||
|
left, right := pair[:idx], pair[idx+1:]
|
||||||
|
if strings.EqualFold(left, parentKey) && !seen[right] {
|
||||||
|
seen[right] = true
|
||||||
|
cols = append(cols, right)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols // nil if empty
|
||||||
|
}
|
||||||
|
|
||||||
// GetRelationModel gets the model type for a relation field
|
// GetRelationModel gets the model type for a relation field
|
||||||
// It searches for the field by name in the following order (case-insensitive):
|
// It searches for the field by name in the following order (case-insensitive):
|
||||||
// 1. Actual field name
|
// 1. Actual field name
|
||||||
@@ -1158,6 +1357,16 @@ func setFieldValue(field reflect.Value, value interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle map[string]interface{} → nested struct (e.g. relation fields like AFN, DEF)
|
||||||
|
if m, ok := value.(map[string]interface{}); ok {
|
||||||
|
if field.CanAddr() {
|
||||||
|
if err := MapToStruct(m, field.Addr().Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fallback: Try to find a "Val" field (for SqlNull types) and set it directly
|
// Fallback: Try to find a "Val" field (for SqlNull types) and set it directly
|
||||||
valField := field.FieldByName("Val")
|
valField := field.FieldByName("Val")
|
||||||
if valField.IsValid() && valField.CanSet() {
|
if valField.IsValid() && valField.CanSet() {
|
||||||
@@ -1376,9 +1585,14 @@ func convertToFloat64(value interface{}) (float64, bool) {
|
|||||||
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
func GetValidJSONFieldNames(modelType reflect.Type) map[string]bool {
|
||||||
validFields := make(map[string]bool)
|
validFields := make(map[string]bool)
|
||||||
|
|
||||||
// Unwrap pointers to get to the base struct type
|
// Unwrap pointers and slices to get to the base struct type
|
||||||
for modelType != nil && modelType.Kind() == reflect.Pointer {
|
for modelType != nil {
|
||||||
|
switch modelType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
@@ -1439,8 +1653,13 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType.Kind() == reflect.Ptr {
|
for {
|
||||||
|
switch modelType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
modelType = modelType.Elem()
|
modelType = modelType.Elem()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if modelType == nil || modelType.Kind() != reflect.Struct {
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
@@ -1503,17 +1722,16 @@ func getRelationModelSingleLevel(model interface{}, fieldName string) interface{
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Slice {
|
for {
|
||||||
|
switch targetType.Kind() {
|
||||||
|
case reflect.Ptr, reflect.Slice:
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
if targetType == nil {
|
if targetType == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if targetType.Kind() == reflect.Ptr {
|
break
|
||||||
targetType = targetType.Elem()
|
|
||||||
if targetType == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Struct {
|
if targetType.Kind() != reflect.Struct {
|
||||||
|
|||||||
@@ -0,0 +1,168 @@
|
|||||||
|
package reflection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --- local test models ---
|
||||||
|
|
||||||
|
type fkDept struct{}
|
||||||
|
|
||||||
|
// bunEmployee uses bun join: tag to declare the FK column explicitly.
|
||||||
|
type bunEmployee struct {
|
||||||
|
DeptID string `bun:"dept_id" json:"dept_id"`
|
||||||
|
Department *fkDept `bun:"rel:belongs-to,join:dept_id=id" json:"department"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// bunCompositeEmployee has a composite bun join: (two join: parts).
|
||||||
|
type bunCompositeEmployee struct {
|
||||||
|
DeptID string `bun:"dept_id" json:"dept_id"`
|
||||||
|
TenantID string `bun:"tenant_id" json:"tenant_id"`
|
||||||
|
Department *fkDept `bun:"rel:belongs-to,join:dept_id=id,join:tenant_id=id" json:"department"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// gormEmployee uses gorm foreignKey: tag (mirrors testmodels.Employee).
|
||||||
|
type gormEmployee struct {
|
||||||
|
DepartmentID string `json:"department_id"`
|
||||||
|
ManagerID string `json:"manager_id"`
|
||||||
|
Department *fkDept `gorm:"foreignKey:DepartmentID;references:ID" json:"department"`
|
||||||
|
Manager *fkDept `gorm:"foreignKey:ManagerID;references:ID" json:"manager"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// gormCompositeEmployee has a composite GORM foreignKey.
|
||||||
|
type gormCompositeEmployee struct {
|
||||||
|
DeptID string `json:"dept_id"`
|
||||||
|
TenantID string `json:"tenant_id"`
|
||||||
|
Department *fkDept `gorm:"foreignKey:DeptID,TenantID" json:"department"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// selfRefItem mimics a self-referential model (like mastertaskitem) where the
|
||||||
|
// parent PK column appears as the left side of a has-many join tag.
|
||||||
|
type selfRefItem struct {
|
||||||
|
RidItem int32 `json:"rid_item" bun:"rid_item,type:integer,pk"`
|
||||||
|
RidParentItem int32 `json:"rid_parentitem" bun:"rid_parentitem,type:integer"`
|
||||||
|
// has-one (single parent pointer)
|
||||||
|
Parent *selfRefItem `json:"Parent,omitempty" bun:"rel:has-one,join:rid_item=rid_parentitem"`
|
||||||
|
// has-many (child collection) — same join, duplicate right-side must be deduped
|
||||||
|
Children []*selfRefItem `json:"Children,omitempty" bun:"rel:has-many,join:rid_item=rid_parentitem"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// conventionEmployee has no explicit FK tag — relies on naming convention.
|
||||||
|
type conventionEmployee struct {
|
||||||
|
DepartmentID string `json:"department_id"`
|
||||||
|
Department *fkDept `json:"department"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// noTagEmployee has a relation field with no FK tag and no convention match.
|
||||||
|
type noTagEmployee struct {
|
||||||
|
Unrelated *fkDept `json:"unrelated"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetForeignKeyColumn(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
modelType reflect.Type
|
||||||
|
parentKey string
|
||||||
|
want []string
|
||||||
|
}{
|
||||||
|
// Bun join: tag
|
||||||
|
{
|
||||||
|
name: "bun join tag returns local column",
|
||||||
|
modelType: reflect.TypeOf(bunEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"dept_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bun join tag matched via json tag (case-insensitive)",
|
||||||
|
modelType: reflect.TypeOf(bunEmployee{}),
|
||||||
|
parentKey: "Department",
|
||||||
|
want: []string{"dept_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bun composite join returns all local columns",
|
||||||
|
modelType: reflect.TypeOf(bunCompositeEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"dept_id", "tenant_id"},
|
||||||
|
},
|
||||||
|
|
||||||
|
// GORM foreignKey: tag
|
||||||
|
{
|
||||||
|
name: "gorm foreignKey resolves to column name",
|
||||||
|
modelType: reflect.TypeOf(gormEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"department_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gorm foreignKey resolves second relation",
|
||||||
|
modelType: reflect.TypeOf(gormEmployee{}),
|
||||||
|
parentKey: "manager",
|
||||||
|
want: []string{"manager_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gorm foreignKey matched case-insensitively",
|
||||||
|
modelType: reflect.TypeOf(gormEmployee{}),
|
||||||
|
parentKey: "Department",
|
||||||
|
want: []string{"department_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gorm composite foreignKey returns all columns",
|
||||||
|
modelType: reflect.TypeOf(gormCompositeEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"dept_id", "tenant_id"},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Join left-side scan (parentKey is a raw column name, not a relation field name)
|
||||||
|
{
|
||||||
|
name: "self-referential: parent PK column returns child FK column",
|
||||||
|
modelType: reflect.TypeOf(selfRefItem{}),
|
||||||
|
parentKey: "rid_item",
|
||||||
|
want: []string{"rid_parentitem"},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Pointer and slice unwrapping
|
||||||
|
{
|
||||||
|
name: "pointer to struct is unwrapped",
|
||||||
|
modelType: reflect.TypeOf(&gormEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"department_id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "slice of struct is unwrapped",
|
||||||
|
modelType: reflect.TypeOf([]gormEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: []string{"department_id"},
|
||||||
|
},
|
||||||
|
|
||||||
|
// No tag — returns nil so caller can fall back to convention
|
||||||
|
{
|
||||||
|
name: "relation with no FK tag returns nil",
|
||||||
|
modelType: reflect.TypeOf(conventionEmployee{}),
|
||||||
|
parentKey: "department",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
|
||||||
|
// Unknown parent key
|
||||||
|
{
|
||||||
|
name: "unknown parent key returns nil",
|
||||||
|
modelType: reflect.TypeOf(gormEmployee{}),
|
||||||
|
parentKey: "nonexistent",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-struct type returns nil",
|
||||||
|
modelType: reflect.TypeOf(""),
|
||||||
|
parentKey: "department",
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := GetForeignKeyColumn(tt.modelType, tt.parentKey)
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("GetForeignKeyColumn(%v, %q) = %v, want %v", tt.modelType, tt.parentKey, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -221,6 +221,124 @@ func TestMapToStruct_AllSqlTypes(t *testing.T) {
|
|||||||
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
|
t.Logf(" - SqlJSONB (Tags): %v", tagsValue)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMapToStruct_NestedStructPointer tests that a map[string]interface{} value is
|
||||||
|
// correctly converted into a pointer-to-struct field (e.g. AFN *ModelCoreActionfunction).
|
||||||
|
func TestMapToStruct_NestedStructPointer(t *testing.T) {
|
||||||
|
type Inner struct {
|
||||||
|
ID spectypes.SqlInt32 `bun:"rid_inner,pk" json:"rid_inner"`
|
||||||
|
Name spectypes.SqlString `bun:"name" json:"name"`
|
||||||
|
}
|
||||||
|
type Outer struct {
|
||||||
|
ID spectypes.SqlInt32 `bun:"rid_outer,pk" json:"rid_outer"`
|
||||||
|
Inner *Inner `json:"inner,omitempty" bun:"rel:has-one"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"rid_outer": int64(1),
|
||||||
|
"inner": map[string]interface{}{
|
||||||
|
"rid_inner": int64(42),
|
||||||
|
"name": "hello",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result Outer
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.ID.Valid || result.ID.Val != 1 {
|
||||||
|
t.Errorf("ID = %v, want 1", result.ID)
|
||||||
|
}
|
||||||
|
if result.Inner == nil {
|
||||||
|
t.Fatal("Inner is nil, want non-nil")
|
||||||
|
}
|
||||||
|
if !result.Inner.ID.Valid || result.Inner.ID.Val != 42 {
|
||||||
|
t.Errorf("Inner.ID = %v, want 42", result.Inner.ID)
|
||||||
|
}
|
||||||
|
if !result.Inner.Name.Valid || result.Inner.Name.Val != "hello" {
|
||||||
|
t.Errorf("Inner.Name = %v, want 'hello'", result.Inner.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMapToStruct_NestedStructNilPointer tests that a nil map value leaves the pointer nil.
|
||||||
|
func TestMapToStruct_NestedStructNilPointer(t *testing.T) {
|
||||||
|
type Inner struct {
|
||||||
|
ID spectypes.SqlInt32 `bun:"rid_inner,pk" json:"rid_inner"`
|
||||||
|
}
|
||||||
|
type Outer struct {
|
||||||
|
ID spectypes.SqlInt32 `bun:"rid_outer,pk" json:"rid_outer"`
|
||||||
|
Inner *Inner `json:"inner,omitempty" bun:"rel:has-one"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"rid_outer": int64(5),
|
||||||
|
"inner": nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
var result Outer
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Inner != nil {
|
||||||
|
t.Errorf("Inner = %v, want nil", result.Inner)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMapToStruct_NestedStructWithSpectypes mirrors the real-world case of
|
||||||
|
// ModelCoreActionoption.AFN being populated from map[string]interface{}.
|
||||||
|
func TestMapToStruct_NestedStructWithSpectypes(t *testing.T) {
|
||||||
|
type ActionFunction struct {
|
||||||
|
Ridactionfunction spectypes.SqlInt32 `bun:"rid_actionfunction,pk" json:"rid_actionfunction"`
|
||||||
|
Functionname spectypes.SqlString `bun:"functionname" json:"functionname"`
|
||||||
|
Fntype spectypes.SqlString `bun:"fntype" json:"fntype"`
|
||||||
|
}
|
||||||
|
type ActionOption struct {
|
||||||
|
Ridactionoption spectypes.SqlInt32 `bun:"rid_actionoption,pk" json:"rid_actionoption"`
|
||||||
|
Ridactionfunction spectypes.SqlInt32 `bun:"rid_actionfunction" json:"rid_actionfunction"`
|
||||||
|
Description spectypes.SqlString `bun:"description" json:"description"`
|
||||||
|
AFN *ActionFunction `json:"AFN,omitempty" bun:"rel:has-one"`
|
||||||
|
}
|
||||||
|
|
||||||
|
dataMap := map[string]interface{}{
|
||||||
|
"rid_actionoption": int64(10),
|
||||||
|
"rid_actionfunction": int64(99),
|
||||||
|
"description": "test option",
|
||||||
|
"AFN": map[string]interface{}{
|
||||||
|
"rid_actionfunction": int64(99),
|
||||||
|
"functionname": "MyFunction",
|
||||||
|
"fntype": "action",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var result ActionOption
|
||||||
|
err := reflection.MapToStruct(dataMap, &result)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("MapToStruct() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.Ridactionoption.Valid || result.Ridactionoption.Val != 10 {
|
||||||
|
t.Errorf("Ridactionoption = %v, want 10", result.Ridactionoption)
|
||||||
|
}
|
||||||
|
if !result.Description.Valid || result.Description.Val != "test option" {
|
||||||
|
t.Errorf("Description = %v, want 'test option'", result.Description)
|
||||||
|
}
|
||||||
|
if result.AFN == nil {
|
||||||
|
t.Fatal("AFN is nil, want non-nil")
|
||||||
|
}
|
||||||
|
if !result.AFN.Ridactionfunction.Valid || result.AFN.Ridactionfunction.Val != 99 {
|
||||||
|
t.Errorf("AFN.Ridactionfunction = %v, want 99", result.AFN.Ridactionfunction)
|
||||||
|
}
|
||||||
|
if !result.AFN.Functionname.Valid || result.AFN.Functionname.Val != "MyFunction" {
|
||||||
|
t.Errorf("AFN.Functionname = %v, want 'MyFunction'", result.AFN.Functionname)
|
||||||
|
}
|
||||||
|
if !result.AFN.Fntype.Valid || result.AFN.Fntype.Val != "action" {
|
||||||
|
t.Errorf("AFN.Fntype = %v, want 'action'", result.AFN.Fntype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
|
func TestMapToStruct_SqlNull_NilValues(t *testing.T) {
|
||||||
// Test that SqlNull types handle nil values correctly
|
// Test that SqlNull types handle nil values correctly
|
||||||
type TestModel struct {
|
type TestModel struct {
|
||||||
|
|||||||
@@ -823,12 +823,12 @@ func TestToSnakeCase(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "UserID",
|
name: "UserID",
|
||||||
input: "UserID",
|
input: "UserID",
|
||||||
expected: "user_i_d",
|
expected: "user_id",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "HTTPServer",
|
name: "HTTPServer",
|
||||||
input: "HTTPServer",
|
input: "HTTPServer",
|
||||||
expected: "h_t_t_p_server",
|
expected: "http_server",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lowercase",
|
name: "lowercase",
|
||||||
@@ -838,7 +838,7 @@ func TestToSnakeCase(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "UPPERCASE",
|
name: "UPPERCASE",
|
||||||
input: "UPPERCASE",
|
input: "UPPERCASE",
|
||||||
expected: "u_p_p_e_r_c_a_s_e",
|
expected: "uppercase",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Single",
|
name: "Single",
|
||||||
|
|||||||
@@ -0,0 +1,671 @@
|
|||||||
|
# resolvemcp
|
||||||
|
|
||||||
|
Package `resolvemcp` exposes registered database models as **Model Context Protocol (MCP) tools and resources** over HTTP/SSE transport. It mirrors the `resolvespec` package patterns — same model registration API, same filter/sort/pagination/preload options, same lifecycle hook system.
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/resolvemcp"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 1. Create a handler
|
||||||
|
handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{
|
||||||
|
BaseURL: "http://localhost:8080",
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. Register models
|
||||||
|
handler.RegisterModel("public", "users", &User{})
|
||||||
|
handler.RegisterModel("public", "orders", &Order{})
|
||||||
|
|
||||||
|
// 3. Mount routes
|
||||||
|
r := mux.NewRouter()
|
||||||
|
resolvemcp.SetupMuxRoutes(r, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
type Config struct {
|
||||||
|
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||||
|
// Sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||||
|
// If empty, it is detected from each incoming request using the Host header and
|
||||||
|
// TLS state (X-Forwarded-Proto is honoured for reverse-proxy deployments).
|
||||||
|
BaseURL string
|
||||||
|
|
||||||
|
// BasePath is the URL path prefix where MCP endpoints are mounted (e.g. "/mcp").
|
||||||
|
// Required.
|
||||||
|
BasePath string
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Handler Creation
|
||||||
|
|
||||||
|
| Function | Description |
|
||||||
|
|---|---|
|
||||||
|
| `NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler` | Backed by GORM |
|
||||||
|
| `NewHandlerWithBun(db *bun.DB, cfg Config) *Handler` | Backed by Bun |
|
||||||
|
| `NewHandlerWithDB(db common.Database, cfg Config) *Handler` | Backed by any `common.Database` |
|
||||||
|
| `NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler` | Full control over registry |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Registering Models
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.RegisterModel(schema, entity string, model interface{}) error
|
||||||
|
```
|
||||||
|
|
||||||
|
- `schema` — database schema name (e.g. `"public"`), or empty string for no schema prefix.
|
||||||
|
- `entity` — table/entity name (e.g. `"users"`).
|
||||||
|
- `model` — a pointer to a struct (e.g. `&User{}`).
|
||||||
|
|
||||||
|
Each call immediately creates four MCP **tools** and one MCP **resource** for the model.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## HTTP Transports
|
||||||
|
|
||||||
|
`Config.BasePath` is required and used for all route registration.
|
||||||
|
`Config.BaseURL` is optional — when empty it is detected from each request.
|
||||||
|
|
||||||
|
Two transports are supported: **SSE** (legacy, two-endpoint) and **Streamable HTTP** (recommended, single-endpoint).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### SSE Transport
|
||||||
|
|
||||||
|
Two endpoints: `GET {BasePath}/sse` (subscribe) + `POST {BasePath}/message` (send).
|
||||||
|
|
||||||
|
#### Gorilla Mux
|
||||||
|
|
||||||
|
```go
|
||||||
|
resolvemcp.SetupMuxRoutes(r, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
| Route | Method | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `{BasePath}/sse` | GET | SSE connection — clients subscribe here |
|
||||||
|
| `{BasePath}/message` | POST | JSON-RPC — clients send requests here |
|
||||||
|
|
||||||
|
#### bunrouter
|
||||||
|
|
||||||
|
```go
|
||||||
|
resolvemcp.SetupBunRouterRoutes(router, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Gin / net/http / Echo
|
||||||
|
|
||||||
|
```go
|
||||||
|
sse := handler.SSEServer()
|
||||||
|
|
||||||
|
engine.Any("/mcp/*path", gin.WrapH(sse)) // Gin
|
||||||
|
http.Handle("/mcp/", sse) // net/http
|
||||||
|
e.Any("/mcp/*", echo.WrapHandler(sse)) // Echo
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Streamable HTTP Transport
|
||||||
|
|
||||||
|
Single endpoint at `{BasePath}`. Handles POST (client→server) and GET (server→client streaming). Preferred for new integrations.
|
||||||
|
|
||||||
|
#### Gorilla Mux
|
||||||
|
|
||||||
|
```go
|
||||||
|
resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
Mounts the handler at `{BasePath}` (all methods).
|
||||||
|
|
||||||
|
#### bunrouter
|
||||||
|
|
||||||
|
```go
|
||||||
|
resolvemcp.SetupBunRouterStreamableHTTPRoutes(router, handler)
|
||||||
|
```
|
||||||
|
|
||||||
|
Registers GET, POST, DELETE on `{BasePath}`.
|
||||||
|
|
||||||
|
#### Gin / net/http / Echo
|
||||||
|
|
||||||
|
```go
|
||||||
|
h := handler.StreamableHTTPServer()
|
||||||
|
// or: h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||||
|
|
||||||
|
engine.Any("/mcp", gin.WrapH(h)) // Gin
|
||||||
|
http.Handle("/mcp", h) // net/http
|
||||||
|
e.Any("/mcp", echo.WrapHandler(h)) // Echo
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## OAuth2 Authentication
|
||||||
|
|
||||||
|
`resolvemcp` ships a full **MCP-standard OAuth2 authorization server** (`pkg/security.OAuthServer`) that MCP clients (Claude Desktop, Cursor, etc.) can discover and use automatically.
|
||||||
|
|
||||||
|
It can operate as:
|
||||||
|
- **Its own identity provider** — shows a login form, validates via `DatabaseAuthenticator.Login()`
|
||||||
|
- **An OAuth2 federation layer** — delegates to external providers (Google, GitHub, Microsoft, etc.)
|
||||||
|
- **Both simultaneously**
|
||||||
|
|
||||||
|
### Standard endpoints served
|
||||||
|
|
||||||
|
| Path | Spec | Purpose |
|
||||||
|
|---|---|---|
|
||||||
|
| `GET /.well-known/oauth-authorization-server` | RFC 8414 | MCP client auto-discovery |
|
||||||
|
| `POST /oauth/register` | RFC 7591 | Dynamic client registration |
|
||||||
|
| `GET /oauth/authorize` | OAuth 2.1 + PKCE | Start login (form or provider redirect) |
|
||||||
|
| `POST /oauth/authorize` | — | Login form submission |
|
||||||
|
| `POST /oauth/token` | OAuth 2.1 | Auth code → Bearer token exchange |
|
||||||
|
| `POST /oauth/token` (refresh) | OAuth 2.1 | Refresh token rotation |
|
||||||
|
| `GET /oauth/provider/callback` | Internal | External provider redirect target |
|
||||||
|
|
||||||
|
MCP clients send `Authorization: Bearer <token>` on all subsequent requests.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Mode 1 — Direct login (server as identity provider)
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
db, _ := sql.Open("postgres", dsn)
|
||||||
|
auth := security.NewDatabaseAuthenticator(db)
|
||||||
|
|
||||||
|
handler := resolvemcp.NewHandlerWithGORM(gormDB, resolvemcp.Config{
|
||||||
|
BaseURL: "https://api.example.com",
|
||||||
|
BasePath: "/mcp",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Enable the OAuth2 server — auth enables the login form
|
||||||
|
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||||
|
Issuer: "https://api.example.com",
|
||||||
|
}, auth)
|
||||||
|
|
||||||
|
provider, _ := security.NewCompositeSecurityProvider(auth, colSec, rowSec)
|
||||||
|
securityList, _ := security.NewSecurityList(provider)
|
||||||
|
security.RegisterSecurityHooks(handler, securityList)
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||||
|
```
|
||||||
|
|
||||||
|
MCP client flow:
|
||||||
|
1. Discovers server at `/.well-known/oauth-authorization-server`
|
||||||
|
2. Registers itself at `/oauth/register`
|
||||||
|
3. Redirects user to `/oauth/authorize` → login form appears
|
||||||
|
4. On submit, exchanges code at `/oauth/token` → receives `Authorization: Bearer` token
|
||||||
|
5. Uses token on all MCP tool calls
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Mode 2 — External provider (Google, GitHub, etc.)
|
||||||
|
|
||||||
|
The `RedirectURL` in the provider config must point to `/oauth/provider/callback` on this server.
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewDatabaseAuthenticator(db).WithOAuth2(security.OAuth2Config{
|
||||||
|
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
|
||||||
|
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||||
|
RedirectURL: "https://api.example.com/oauth/provider/callback",
|
||||||
|
Scopes: []string{"openid", "profile", "email"},
|
||||||
|
AuthURL: "https://accounts.google.com/o/oauth2/auth",
|
||||||
|
TokenURL: "https://oauth2.googleapis.com/token",
|
||||||
|
UserInfoURL: "https://www.googleapis.com/oauth2/v2/userinfo",
|
||||||
|
ProviderName: "google",
|
||||||
|
})
|
||||||
|
|
||||||
|
// Pass `auth` so the OAuth server supports persistence, introspection, and revocation.
|
||||||
|
// Google handles the end-user authentication flow via redirect.
|
||||||
|
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||||
|
Issuer: "https://api.example.com",
|
||||||
|
}, auth)
|
||||||
|
handler.RegisterOAuth2Provider(auth, "google")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Mode 3 — Both (login form + external providers)
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.EnableOAuthServer(security.OAuthServerConfig{
|
||||||
|
Issuer: "https://api.example.com",
|
||||||
|
LoginTitle: "My App Login",
|
||||||
|
}, auth) // auth enables the username/password form
|
||||||
|
|
||||||
|
handler.RegisterOAuth2Provider(googleAuth, "google")
|
||||||
|
handler.RegisterOAuth2Provider(githubAuth, "github")
|
||||||
|
```
|
||||||
|
|
||||||
|
When external providers are registered they take priority; the login form is used as fallback when no providers are configured.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Using `security.OAuthServer` standalone
|
||||||
|
|
||||||
|
The authorization server lives in `pkg/security` and can be used with any HTTP framework independently of `resolvemcp`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
oauthSrv := security.NewOAuthServer(security.OAuthServerConfig{
|
||||||
|
Issuer: "https://api.example.com",
|
||||||
|
}, auth)
|
||||||
|
oauthSrv.RegisterExternalProvider(googleAuth, "google")
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle("/", oauthSrv.HTTPHandler()) // mounts all OAuth2 routes
|
||||||
|
mux.Handle("/mcp/", myMCPHandler)
|
||||||
|
http.ListenAndServe(":8080", mux)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Cookie-based flow (legacy)
|
||||||
|
|
||||||
|
For simple setups without full MCP OAuth2 compliance, use the legacy helpers that set a session cookie after external provider login:
|
||||||
|
|
||||||
|
```go
|
||||||
|
resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||||
|
ProviderName: "google",
|
||||||
|
LoginPath: "/auth/google/login",
|
||||||
|
CallbackPath: "/auth/google/callback",
|
||||||
|
AfterLoginRedirect: "/",
|
||||||
|
})
|
||||||
|
resolvemcp.SetupMuxRoutesWithAuth(r, handler, securityList)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Security
|
||||||
|
|
||||||
|
`resolvemcp` integrates with the `security` package to provide per-entity access control, row-level security, and column-level security — the same system used by `resolvespec` and `restheadspec`.
|
||||||
|
|
||||||
|
### Wiring security hooks
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
|
||||||
|
securityList := security.NewSecurityList(mySecurityProvider)
|
||||||
|
resolvemcp.RegisterSecurityHooks(handler, securityList)
|
||||||
|
```
|
||||||
|
|
||||||
|
Call `RegisterSecurityHooks` **once**, after creating the handler and before registering models. It installs these controls automatically:
|
||||||
|
|
||||||
|
| Hook | Effect |
|
||||||
|
|---|---|
|
||||||
|
| `BeforeHandle` | Enforces per-entity operation rules (see below) |
|
||||||
|
| `BeforeRead` | Loads RLS/CLS rules, then injects a user-scoped WHERE clause |
|
||||||
|
| `AfterRead` | Masks/hides columns per column-security rules; writes audit log |
|
||||||
|
| `BeforeUpdate` | Blocks update if `CanUpdate` is false |
|
||||||
|
| `BeforeDelete` | Blocks delete if `CanDelete` is false |
|
||||||
|
|
||||||
|
### Per-entity operation rules
|
||||||
|
|
||||||
|
Use `RegisterModelWithRules` instead of `RegisterModel` to set access rules at registration time:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import "github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
|
||||||
|
// Read-only entity
|
||||||
|
handler.RegisterModelWithRules("public", "audit_logs", &AuditLog{}, modelregistry.ModelRules{
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: false,
|
||||||
|
CanUpdate: false,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Public read, authenticated write
|
||||||
|
handler.RegisterModelWithRules("public", "products", &Product{}, modelregistry.ModelRules{
|
||||||
|
CanPublicRead: true,
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: true,
|
||||||
|
CanUpdate: true,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
To update rules for an already-registered model:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.SetModelRules("public", "users", modelregistry.ModelRules{
|
||||||
|
CanRead: true,
|
||||||
|
CanCreate: true,
|
||||||
|
CanUpdate: true,
|
||||||
|
CanDelete: false,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
`RegisterModel` (no rules) registers with all-allowed defaults (`CanRead/Create/Update/Delete = true`).
|
||||||
|
|
||||||
|
### ModelRules fields
|
||||||
|
|
||||||
|
| Field | Default | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `CanPublicRead` | `false` | Allow unauthenticated reads |
|
||||||
|
| `CanPublicCreate` | `false` | Allow unauthenticated creates |
|
||||||
|
| `CanPublicUpdate` | `false` | Allow unauthenticated updates |
|
||||||
|
| `CanPublicDelete` | `false` | Allow unauthenticated deletes |
|
||||||
|
| `CanRead` | `true` | Allow authenticated reads |
|
||||||
|
| `CanCreate` | `true` | Allow authenticated creates |
|
||||||
|
| `CanUpdate` | `true` | Allow authenticated updates |
|
||||||
|
| `CanDelete` | `true` | Allow authenticated deletes |
|
||||||
|
| `SecurityDisabled` | `false` | Skip all security checks for this model |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## MCP Tools
|
||||||
|
|
||||||
|
### Tool Naming
|
||||||
|
|
||||||
|
```
|
||||||
|
{operation}_{schema}_{entity} // e.g. read_public_users
|
||||||
|
{operation}_{entity} // e.g. read_users (when schema is empty)
|
||||||
|
```
|
||||||
|
|
||||||
|
Operations: `read`, `create`, `update`, `delete`.
|
||||||
|
|
||||||
|
### Read Tool — `read_{schema}_{entity}`
|
||||||
|
|
||||||
|
Fetch one or many records.
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `id` | string | Primary key value. Omit to return multiple records. |
|
||||||
|
| `limit` | number | Max records per page (recommended: 10–100). |
|
||||||
|
| `offset` | number | Records to skip (offset-based pagination). |
|
||||||
|
| `cursor_forward` | string | PK of the **last** record on the current page (next-page cursor). |
|
||||||
|
| `cursor_backward` | string | PK of the **first** record on the current page (prev-page cursor). |
|
||||||
|
| `columns` | array | Column names to include. Omit for all columns. |
|
||||||
|
| `omit_columns` | array | Column names to exclude. |
|
||||||
|
| `filters` | array | Filter objects (see [Filtering](#filtering)). |
|
||||||
|
| `sort` | array | Sort objects (see [Sorting](#sorting)). |
|
||||||
|
| `preloads` | array | Relation preload objects (see [Preloading](#preloading)). |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [...],
|
||||||
|
"metadata": {
|
||||||
|
"total": 100,
|
||||||
|
"filtered": 100,
|
||||||
|
"count": 10,
|
||||||
|
"limit": 10,
|
||||||
|
"offset": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create Tool — `create_{schema}_{entity}`
|
||||||
|
|
||||||
|
Insert one or more records.
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `data` | object \| array | Single object or array of objects to insert. |
|
||||||
|
|
||||||
|
Array input runs inside a single transaction — all succeed or all fail.
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "data": { ... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update Tool — `update_{schema}_{entity}`
|
||||||
|
|
||||||
|
Partially update an existing record. Only non-null, non-empty fields in `data` are applied; existing values are preserved for omitted fields.
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `id` | string | Primary key of the record. Can also be included inside `data`. |
|
||||||
|
| `data` | object (required) | Fields to update. |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "data": { ...merged record... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Delete Tool — `delete_{schema}_{entity}`
|
||||||
|
|
||||||
|
Delete a record by primary key. **Irreversible.**
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `id` | string (required) | Primary key of the record to delete. |
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "data": { ...deleted record... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Annotation Tool — `resolvespec_annotate`
|
||||||
|
|
||||||
|
Store or retrieve freeform annotation records for any tool, model, or entity. Registered automatically on every handler.
|
||||||
|
|
||||||
|
| Argument | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `tool_name` | string (required) | Key to annotate — an MCP tool name (e.g. `read_public_users`), a model name (e.g. `public.users`), or any other identifier. |
|
||||||
|
| `annotations` | object | Annotation data to persist. Omit to retrieve existing annotations instead. |
|
||||||
|
|
||||||
|
**Set annotations** (calls `resolvespec_set_annotation(tool_name, annotations)`):
|
||||||
|
```json
|
||||||
|
{ "tool_name": "read_public_users", "annotations": { "description": "Returns active users", "owner": "platform-team" } }
|
||||||
|
```
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "tool_name": "read_public_users", "action": "set" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Get annotations** (calls `resolvespec_get_annotation(tool_name)`):
|
||||||
|
```json
|
||||||
|
{ "tool_name": "read_public_users" }
|
||||||
|
```
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{ "success": true, "tool_name": "read_public_users", "action": "get", "annotations": { ... } }
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Resource — `{schema}.{entity}`
|
||||||
|
|
||||||
|
Each model is also registered as an MCP resource with URI `schema.entity` (or just `entity` when schema is empty). Reading the resource returns up to 100 records as `application/json`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Filtering
|
||||||
|
|
||||||
|
Pass an array of filter objects to the `filters` argument:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{ "column": "status", "operator": "=", "value": "active" },
|
||||||
|
{ "column": "age", "operator": ">", "value": 18, "logic_operator": "AND" },
|
||||||
|
{ "column": "role", "operator": "in", "value": ["admin", "editor"], "logic_operator": "OR" }
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Supported Operators
|
||||||
|
|
||||||
|
| Operator | Aliases | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `=` | `eq` | Equal |
|
||||||
|
| `!=` | `neq`, `<>` | Not equal |
|
||||||
|
| `>` | `gt` | Greater than |
|
||||||
|
| `>=` | `gte` | Greater than or equal |
|
||||||
|
| `<` | `lt` | Less than |
|
||||||
|
| `<=` | `lte` | Less than or equal |
|
||||||
|
| `like` | | SQL LIKE (case-sensitive) |
|
||||||
|
| `ilike` | | SQL ILIKE (case-insensitive) |
|
||||||
|
| `in` | | Value in list |
|
||||||
|
| `is_null` | | Column IS NULL |
|
||||||
|
| `is_not_null` | | Column IS NOT NULL |
|
||||||
|
|
||||||
|
### Logic Operators
|
||||||
|
|
||||||
|
- `"logic_operator": "AND"` (default) — filter is AND-chained with the previous condition.
|
||||||
|
- `"logic_operator": "OR"` — filter is OR-grouped with the previous condition.
|
||||||
|
|
||||||
|
Consecutive OR filters are grouped into a single `(cond1 OR cond2 OR ...)` clause.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Sorting
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{ "column": "created_at", "direction": "desc" },
|
||||||
|
{ "column": "name", "direction": "asc" }
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Pagination
|
||||||
|
|
||||||
|
### Offset-Based
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "limit": 20, "offset": 40 }
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cursor-Based
|
||||||
|
|
||||||
|
Cursor pagination uses a SQL `EXISTS` subquery for stable, efficient paging. Always pair with a `sort` argument.
|
||||||
|
|
||||||
|
```json
|
||||||
|
// Next page: pass the PK of the last record on the current page
|
||||||
|
{ "cursor_forward": "42", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||||
|
|
||||||
|
// Previous page: pass the PK of the first record on the current page
|
||||||
|
{ "cursor_backward": "23", "limit": 20, "sort": [{"column": "id", "direction": "asc"}] }
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Preloading Relations
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{ "relation": "Profile" },
|
||||||
|
{ "relation": "Orders" }
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Available relations are listed in each tool's description. Only relations defined on the model struct are valid.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Hook System
|
||||||
|
|
||||||
|
Hooks let you intercept and modify CRUD operations at well-defined lifecycle points.
|
||||||
|
|
||||||
|
### Hook Types
|
||||||
|
|
||||||
|
| Constant | Fires |
|
||||||
|
|---|---|
|
||||||
|
| `BeforeHandle` | After model resolution, before operation dispatch (all CRUD) |
|
||||||
|
| `BeforeRead` / `AfterRead` | Around read queries |
|
||||||
|
| `BeforeCreate` / `AfterCreate` | Around insert |
|
||||||
|
| `BeforeUpdate` / `AfterUpdate` | Around update |
|
||||||
|
| `BeforeDelete` / `AfterDelete` | Around delete |
|
||||||
|
|
||||||
|
### Registering Hooks
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.Hooks().Register(resolvemcp.BeforeCreate, func(ctx *resolvemcp.HookContext) error {
|
||||||
|
// Inject a timestamp before insert
|
||||||
|
if data, ok := ctx.Data.(map[string]interface{}); ok {
|
||||||
|
data["created_at"] = time.Now()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register the same hook for multiple events
|
||||||
|
handler.Hooks().RegisterMultiple(
|
||||||
|
[]resolvemcp.HookType{resolvemcp.BeforeCreate, resolvemcp.BeforeUpdate},
|
||||||
|
auditHook,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### HookContext Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `Context` | `context.Context` | Request context |
|
||||||
|
| `Handler` | `*Handler` | The resolvemcp handler |
|
||||||
|
| `Schema` | `string` | Database schema name |
|
||||||
|
| `Entity` | `string` | Entity/table name |
|
||||||
|
| `Model` | `interface{}` | Registered model instance |
|
||||||
|
| `Options` | `common.RequestOptions` | Parsed request options (read operations) |
|
||||||
|
| `Operation` | `string` | `"read"`, `"create"`, `"update"`, or `"delete"` |
|
||||||
|
| `ID` | `string` | Primary key from request (read/update/delete) |
|
||||||
|
| `Data` | `interface{}` | Input data (create/update — modifiable) |
|
||||||
|
| `Result` | `interface{}` | Output data (set by After hooks) |
|
||||||
|
| `Error` | `error` | Operation error, if any |
|
||||||
|
| `Query` | `common.SelectQuery` | Live query object (available in `BeforeRead`) |
|
||||||
|
| `Tx` | `common.Database` | Database/transaction handle |
|
||||||
|
| `Abort` | `bool` | Set to `true` to abort the operation |
|
||||||
|
| `AbortMessage` | `string` | Error message returned when aborting |
|
||||||
|
| `AbortCode` | `int` | Optional status code for the abort |
|
||||||
|
|
||||||
|
### Aborting an Operation
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler.Hooks().Register(resolvemcp.BeforeDelete, func(ctx *resolvemcp.HookContext) error {
|
||||||
|
ctx.Abort = true
|
||||||
|
ctx.AbortMessage = "deletion is disabled"
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Managing Hooks
|
||||||
|
|
||||||
|
```go
|
||||||
|
registry := handler.Hooks()
|
||||||
|
registry.HasHooks(resolvemcp.BeforeCreate) // bool
|
||||||
|
registry.Clear(resolvemcp.BeforeCreate) // remove hooks for one type
|
||||||
|
registry.ClearAll() // remove all hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Context Helpers
|
||||||
|
|
||||||
|
Request metadata is threaded through `context.Context` during handler execution. Hooks and custom tools can read it:
|
||||||
|
|
||||||
|
```go
|
||||||
|
schema := resolvemcp.GetSchema(ctx)
|
||||||
|
entity := resolvemcp.GetEntity(ctx)
|
||||||
|
tableName := resolvemcp.GetTableName(ctx)
|
||||||
|
model := resolvemcp.GetModel(ctx)
|
||||||
|
modelPtr := resolvemcp.GetModelPtr(ctx)
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also set values manually (e.g. in middleware):
|
||||||
|
|
||||||
|
```go
|
||||||
|
ctx = resolvemcp.WithSchema(ctx, "tenant_a")
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Adding Custom MCP Tools
|
||||||
|
|
||||||
|
Access the underlying `*server.MCPServer` to register additional tools:
|
||||||
|
|
||||||
|
```go
|
||||||
|
mcpServer := handler.MCPServer()
|
||||||
|
mcpServer.AddTool(myTool, myHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table Name Resolution
|
||||||
|
|
||||||
|
The handler resolves table names in priority order:
|
||||||
|
|
||||||
|
1. `TableNameProvider` interface — `TableName() string` (can return `"schema.table"`)
|
||||||
|
2. `SchemaProvider` interface — `SchemaName() string` (combined with entity name)
|
||||||
|
3. Fallback: `schema.entity` (or `schema_entity` for SQLite)
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
const annotationToolName = "resolvespec_annotate"
|
||||||
|
|
||||||
|
// registerAnnotationTool adds the resolvespec_annotate tool to the MCP server.
|
||||||
|
// The tool lets models/entities store and retrieve freeform annotation records
|
||||||
|
// using the resolvespec_set_annotation / resolvespec_get_annotation database procedures.
|
||||||
|
func registerAnnotationTool(h *Handler) {
|
||||||
|
tool := mcp.NewTool(annotationToolName,
|
||||||
|
mcp.WithDescription(
|
||||||
|
"Store or retrieve annotations for any MCP tool, model, or entity.\n\n"+
|
||||||
|
"To set annotations: provide both 'tool_name' and 'annotations'. "+
|
||||||
|
"Calls resolvespec_set_annotation(tool_name, annotations) to persist the data.\n\n"+
|
||||||
|
"To get annotations: provide only 'tool_name'. "+
|
||||||
|
"Calls resolvespec_get_annotation(tool_name) and returns the stored annotations.\n\n"+
|
||||||
|
"'tool_name' may be any identifier: an MCP tool name (e.g. 'read_public_users'), "+
|
||||||
|
"a model/entity name (e.g. 'public.users'), or any other key.",
|
||||||
|
),
|
||||||
|
mcp.WithString("tool_name",
|
||||||
|
mcp.Description("Name of the tool, model, or entity to annotate (e.g. 'read_public_users', 'public.users')."),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
mcp.WithObject("annotations",
|
||||||
|
mcp.Description("Annotation data to store. Omit to retrieve existing annotations instead of setting them."),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
|
||||||
|
toolName, ok := args["tool_name"].(string)
|
||||||
|
if !ok || toolName == "" {
|
||||||
|
return mcp.NewToolResultError("missing required argument: tool_name"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
annotations, hasAnnotations := args["annotations"]
|
||||||
|
|
||||||
|
if hasAnnotations && annotations != nil {
|
||||||
|
return executeSetAnnotation(ctx, h, toolName, annotations)
|
||||||
|
}
|
||||||
|
return executeGetAnnotation(ctx, h, toolName)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeSetAnnotation(ctx context.Context, h *Handler, toolName string, annotations interface{}) (*mcp.CallToolResult, error) {
|
||||||
|
jsonBytes, err := json.Marshal(annotations)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to marshal annotations: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = h.db.Exec(ctx, "SELECT resolvespec_set_annotation($1, $2)", toolName, string(jsonBytes))
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to set annotation: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"tool_name": toolName,
|
||||||
|
"action": "set",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeGetAnnotation(ctx context.Context, h *Handler, toolName string) (*mcp.CallToolResult, error) {
|
||||||
|
var rows []map[string]interface{}
|
||||||
|
err := h.db.Query(ctx, &rows, "SELECT resolvespec_get_annotation($1)", toolName)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("failed to get annotation: %v", err)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var annotations interface{}
|
||||||
|
if len(rows) > 0 {
|
||||||
|
// The procedure returns a single value; extract the first column of the first row.
|
||||||
|
for _, v := range rows[0] {
|
||||||
|
annotations = v
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value is a []byte or string containing JSON, decode it so it round-trips cleanly.
|
||||||
|
switch v := annotations.(type) {
|
||||||
|
case []byte:
|
||||||
|
var decoded interface{}
|
||||||
|
if json.Unmarshal(v, &decoded) == nil {
|
||||||
|
annotations = decoded
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
var decoded interface{}
|
||||||
|
if json.Unmarshal([]byte(v), &decoded) == nil {
|
||||||
|
annotations = decoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"tool_name": toolName,
|
||||||
|
"action": "get",
|
||||||
|
"annotations": annotations,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const (
|
||||||
|
contextKeySchema contextKey = "schema"
|
||||||
|
contextKeyEntity contextKey = "entity"
|
||||||
|
contextKeyTableName contextKey = "tableName"
|
||||||
|
contextKeyModel contextKey = "model"
|
||||||
|
contextKeyModelPtr contextKey = "modelPtr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WithSchema(ctx context.Context, schema string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeySchema, schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSchema(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeySchema); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithEntity(ctx context.Context, entity string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyEntity, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetEntity(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeyEntity); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTableName(ctx context.Context, tableName string) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyTableName, tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetTableName(ctx context.Context) string {
|
||||||
|
if v := ctx.Value(contextKeyTableName); v != nil {
|
||||||
|
return v.(string)
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModel(ctx context.Context, model interface{}) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyModel, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModel(ctx context.Context) interface{} {
|
||||||
|
return ctx.Value(contextKeyModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithModelPtr(ctx context.Context, modelPtr interface{}) context.Context {
|
||||||
|
return context.WithValue(ctx, contextKeyModelPtr, modelPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelPtr(ctx context.Context) interface{} {
|
||||||
|
return ctx.Value(contextKeyModelPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func withRequestData(ctx context.Context, schema, entity, tableName string, model, modelPtr interface{}) context.Context {
|
||||||
|
ctx = WithSchema(ctx, schema)
|
||||||
|
ctx = WithEntity(ctx, entity)
|
||||||
|
ctx = WithTableName(ctx, tableName)
|
||||||
|
ctx = WithModel(ctx, model)
|
||||||
|
ctx = WithModelPtr(ctx, modelPtr)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
// Cursor-based pagination adapted from pkg/resolvespec/cursor.go.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cursorDirection int
|
||||||
|
|
||||||
|
const (
|
||||||
|
cursorForward cursorDirection = 1
|
||||||
|
cursorBackward cursorDirection = -1
|
||||||
|
)
|
||||||
|
|
||||||
|
// getCursorFilter generates a SQL EXISTS subquery for cursor-based pagination.
|
||||||
|
// expandJoins is an optional map[alias]string of JOIN clauses for join-column sort support.
|
||||||
|
func getCursorFilter(
|
||||||
|
tableName string,
|
||||||
|
pkName string,
|
||||||
|
modelColumns []string,
|
||||||
|
options common.RequestOptions,
|
||||||
|
expandJoins map[string]string,
|
||||||
|
) (string, error) {
|
||||||
|
fullTableName := tableName
|
||||||
|
if strings.Contains(tableName, ".") {
|
||||||
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorID, direction := getActiveCursor(options)
|
||||||
|
if cursorID == "" {
|
||||||
|
return "", fmt.Errorf("no cursor provided for table %s", tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortItems := options.Sort
|
||||||
|
if len(sortItems) == 0 {
|
||||||
|
return "", fmt.Errorf("no sort columns defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
var whereClauses []string
|
||||||
|
joinSQL := ""
|
||||||
|
reverse := direction < 0
|
||||||
|
|
||||||
|
for _, s := range sortItems {
|
||||||
|
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||||
|
if col == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(col, ".")
|
||||||
|
field := strings.TrimSpace(parts[len(parts)-1])
|
||||||
|
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||||
|
|
||||||
|
desc := strings.EqualFold(s.Direction, "desc")
|
||||||
|
if reverse {
|
||||||
|
desc = !desc
|
||||||
|
}
|
||||||
|
|
||||||
|
cursorCol, targetCol, isJoin, err := resolveCursorColumn(field, prefix, tableName, modelColumns)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Skipping invalid sort column %q: %v", col, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if isJoin {
|
||||||
|
if expandJoins != nil {
|
||||||
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
|
jSQL, cRef := rewriteCursorJoin(joinClause, tableName, prefix)
|
||||||
|
joinSQL = jSQL
|
||||||
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
op := "<"
|
||||||
|
if desc {
|
||||||
|
op = ">"
|
||||||
|
}
|
||||||
|
whereClauses = append(whereClauses, fmt.Sprintf("%s %s %s", cursorCol, op, targetCol))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(whereClauses) == 0 {
|
||||||
|
return "", fmt.Errorf("no valid sort columns after filtering")
|
||||||
|
}
|
||||||
|
|
||||||
|
orSQL := buildCursorPriorityChain(whereClauses)
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM %s cursor_select
|
||||||
|
%s
|
||||||
|
WHERE cursor_select.%s = %s
|
||||||
|
AND (%s)
|
||||||
|
)`,
|
||||||
|
fullTableName,
|
||||||
|
joinSQL,
|
||||||
|
pkName,
|
||||||
|
cursorID,
|
||||||
|
orSQL,
|
||||||
|
)
|
||||||
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getActiveCursor(options common.RequestOptions) (id string, direction cursorDirection) {
|
||||||
|
if options.CursorForward != "" {
|
||||||
|
return options.CursorForward, cursorForward
|
||||||
|
}
|
||||||
|
if options.CursorBackward != "" {
|
||||||
|
return options.CursorBackward, cursorBackward
|
||||||
|
}
|
||||||
|
return "", 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCursorColumn(field, prefix, tableName string, modelColumns []string) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||||
|
if strings.Contains(field, "->") {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelColumns != nil {
|
||||||
|
for _, col := range modelColumns {
|
||||||
|
if strings.EqualFold(col, field) {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if prefix != "" && prefix != tableName {
|
||||||
|
return "", "", true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteCursorJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||||
|
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||||
|
cursorAlias = "cursor_select_" + alias
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||||
|
return joinSQL, cursorAlias
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCursorPriorityChain(clauses []string) string {
|
||||||
|
var or []string
|
||||||
|
for i := 0; i < len(clauses); i++ {
|
||||||
|
and := strings.Join(clauses[:i+1], "\n AND ")
|
||||||
|
or = append(or, "("+and+")")
|
||||||
|
}
|
||||||
|
return strings.Join(or, "\n OR ")
|
||||||
|
}
|
||||||
@@ -0,0 +1,761 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/server"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler exposes registered database models as MCP tools and resources.
|
||||||
|
type Handler struct {
|
||||||
|
db common.Database
|
||||||
|
registry common.ModelRegistry
|
||||||
|
hooks *HookRegistry
|
||||||
|
mcpServer *server.MCPServer
|
||||||
|
config Config
|
||||||
|
name string
|
||||||
|
version string
|
||||||
|
oauth2Regs []oauth2Registration
|
||||||
|
oauthSrv *security.OAuthServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a Handler with the given database, model registry, and config.
|
||||||
|
func NewHandler(db common.Database, registry common.ModelRegistry, cfg Config) *Handler {
|
||||||
|
h := &Handler{
|
||||||
|
db: db,
|
||||||
|
registry: registry,
|
||||||
|
hooks: NewHookRegistry(),
|
||||||
|
mcpServer: server.NewMCPServer("resolvemcp", "1.0.0"),
|
||||||
|
config: cfg,
|
||||||
|
name: "resolvemcp",
|
||||||
|
version: "1.0.0",
|
||||||
|
}
|
||||||
|
registerAnnotationTool(h)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns the hook registry.
|
||||||
|
func (h *Handler) Hooks() *HookRegistry {
|
||||||
|
return h.hooks
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDatabase returns the underlying database.
|
||||||
|
func (h *Handler) GetDatabase() common.Database {
|
||||||
|
return h.db
|
||||||
|
}
|
||||||
|
|
||||||
|
// MCPServer returns the underlying MCP server, e.g. to add custom tools.
|
||||||
|
func (h *Handler) MCPServer() *server.MCPServer {
|
||||||
|
return h.mcpServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSEServer returns an http.Handler that serves MCP over SSE.
|
||||||
|
// Config.BasePath must be set. Config.BaseURL is used when set; if empty it is
|
||||||
|
// detected automatically from each incoming request.
|
||||||
|
func (h *Handler) SSEServer() http.Handler {
|
||||||
|
if h.config.BaseURL != "" {
|
||||||
|
return h.newSSEServer(h.config.BaseURL, h.config.BasePath)
|
||||||
|
}
|
||||||
|
return &dynamicSSEHandler{h: h}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamableHTTPServer returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||||
|
// Unlike SSE (which requires two endpoints), streamable HTTP uses a single endpoint for all
|
||||||
|
// client-server communication (POST for requests, GET for server-initiated messages).
|
||||||
|
// Mount the returned handler at the desired path; the path itself becomes the MCP endpoint.
|
||||||
|
func (h *Handler) StreamableHTTPServer() http.Handler {
|
||||||
|
return server.NewStreamableHTTPServer(h.mcpServer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newSSEServer creates a concrete *server.SSEServer for known baseURL and basePath values.
|
||||||
|
func (h *Handler) newSSEServer(baseURL, basePath string) *server.SSEServer {
|
||||||
|
return server.NewSSEServer(
|
||||||
|
h.mcpServer,
|
||||||
|
server.WithBaseURL(baseURL),
|
||||||
|
server.WithStaticBasePath(basePath),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dynamicSSEHandler detects BaseURL from each request and delegates to a cached
|
||||||
|
// *server.SSEServer per detected baseURL. Used when Config.BaseURL is empty.
|
||||||
|
type dynamicSSEHandler struct {
|
||||||
|
h *Handler
|
||||||
|
mu sync.Mutex
|
||||||
|
pool map[string]*server.SSEServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dynamicSSEHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
baseURL := requestBaseURL(r)
|
||||||
|
|
||||||
|
d.mu.Lock()
|
||||||
|
if d.pool == nil {
|
||||||
|
d.pool = make(map[string]*server.SSEServer)
|
||||||
|
}
|
||||||
|
s, ok := d.pool[baseURL]
|
||||||
|
if !ok {
|
||||||
|
s = d.h.newSSEServer(baseURL, d.h.config.BasePath)
|
||||||
|
d.pool[baseURL] = s
|
||||||
|
}
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
s.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestBaseURL builds the base URL from an incoming request.
|
||||||
|
// It honours the X-Forwarded-Proto header for deployments behind a proxy.
|
||||||
|
func requestBaseURL(r *http.Request) string {
|
||||||
|
scheme := "http"
|
||||||
|
if r.TLS != nil {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
if proto := r.Header.Get("X-Forwarded-Proto"); proto != "" {
|
||||||
|
scheme = proto
|
||||||
|
}
|
||||||
|
return scheme + "://" + r.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModel registers a model and immediately exposes it as MCP tools and a resource.
|
||||||
|
func (h *Handler) RegisterModel(schema, entity string, model interface{}) error {
|
||||||
|
fullName := buildModelName(schema, entity)
|
||||||
|
if err := h.registry.RegisterModel(fullName, model); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registerModelTools(h, schema, entity, model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterModelWithRules registers a model and sets per-entity operation rules
|
||||||
|
// (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*, SecurityDisabled).
|
||||||
|
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||||
|
func (h *Handler) RegisterModelWithRules(schema, entity string, model interface{}, rules modelregistry.ModelRules) error {
|
||||||
|
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||||
|
}
|
||||||
|
fullName := buildModelName(schema, entity)
|
||||||
|
if err := reg.RegisterModelWithRules(fullName, model, rules); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
registerModelTools(h, schema, entity, model)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetModelRules updates the operation rules for an already-registered model.
|
||||||
|
// Requires RegisterSecurityHooks to have been called for the rules to be enforced.
|
||||||
|
func (h *Handler) SetModelRules(schema, entity string, rules modelregistry.ModelRules) error {
|
||||||
|
reg, ok := h.registry.(*modelregistry.DefaultModelRegistry)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("resolvemcp: registry does not support model rules (use NewHandlerWithGORM/Bun/DB)")
|
||||||
|
}
|
||||||
|
return reg.SetModelRules(buildModelName(schema, entity), rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelName builds the registry key for a model (same format as resolvespec).
|
||||||
|
func buildModelName(schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return entity
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", schema, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTableName returns the fully qualified table name for a model.
|
||||||
|
func (h *Handler) getTableName(schema, entity string, model interface{}) string {
|
||||||
|
schemaName, tableName := h.getSchemaAndTable(schema, entity, model)
|
||||||
|
if schemaName != "" {
|
||||||
|
if h.db.DriverName() == "sqlite" {
|
||||||
|
return fmt.Sprintf("%s_%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", schemaName, tableName)
|
||||||
|
}
|
||||||
|
return tableName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) getSchemaAndTable(defaultSchema, entity string, model interface{}) (schema, table string) {
|
||||||
|
if tableProvider, ok := model.(common.TableNameProvider); ok {
|
||||||
|
tableName := tableProvider.TableName()
|
||||||
|
if idx := strings.LastIndex(tableName, "."); idx != -1 {
|
||||||
|
return tableName[:idx], tableName[idx+1:]
|
||||||
|
}
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
return schemaProvider.SchemaName(), tableName
|
||||||
|
}
|
||||||
|
return defaultSchema, tableName
|
||||||
|
}
|
||||||
|
if schemaProvider, ok := model.(common.SchemaProvider); ok {
|
||||||
|
return schemaProvider.SchemaName(), entity
|
||||||
|
}
|
||||||
|
return defaultSchema, entity
|
||||||
|
}
|
||||||
|
|
||||||
|
// recoverPanic catches a panic from the current goroutine and returns it as an error.
|
||||||
|
// Usage: defer recoverPanic(&returnedErr)
|
||||||
|
func recoverPanic(err *error) {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
msg := fmt.Sprintf("%v", r)
|
||||||
|
logger.Error("[resolvemcp] panic recovered: %s", msg)
|
||||||
|
*err = fmt.Errorf("internal error: %s", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeRead reads records from the database and returns raw data + metadata.
|
||||||
|
func (h *Handler) executeRead(ctx context.Context, schema, entity, id string, options common.RequestOptions) (_ interface{}, _ *common.Metadata, retErr error) {
|
||||||
|
defer recoverPanic(&retErr)
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
unwrapped, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = unwrapped.Model
|
||||||
|
modelType := unwrapped.ModelType
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, unwrapped.ModelPtr)
|
||||||
|
|
||||||
|
validator := common.NewColumnValidator(model)
|
||||||
|
options = validator.FilterRequestOptions(options)
|
||||||
|
|
||||||
|
// BeforeHandle hook
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "read",
|
||||||
|
Options: options,
|
||||||
|
ID: id,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sliceType := reflect.SliceOf(reflect.PointerTo(modelType))
|
||||||
|
modelPtr := reflect.New(sliceType).Interface()
|
||||||
|
|
||||||
|
query := h.db.NewSelect().Model(modelPtr)
|
||||||
|
|
||||||
|
tempInstance := reflect.New(modelType).Interface()
|
||||||
|
if provider, ok := tempInstance.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
|
query = query.Table(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column selection
|
||||||
|
if len(options.Columns) == 0 && len(options.ComputedColumns) > 0 {
|
||||||
|
options.Columns = reflection.GetSQLModelColumns(model)
|
||||||
|
}
|
||||||
|
for _, col := range options.Columns {
|
||||||
|
query = query.Column(reflection.ExtractSourceColumn(col))
|
||||||
|
}
|
||||||
|
for _, cu := range options.ComputedColumns {
|
||||||
|
query = query.ColumnExpr(fmt.Sprintf("(%s) AS %s", cu.Expression, cu.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filters
|
||||||
|
query = h.applyFilters(query, options.Filters)
|
||||||
|
|
||||||
|
// Custom operators
|
||||||
|
for _, customOp := range options.CustomOperators {
|
||||||
|
query = query.Where(customOp.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sorting
|
||||||
|
for _, sort := range options.Sort {
|
||||||
|
direction := "ASC"
|
||||||
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
|
direction = "DESC"
|
||||||
|
}
|
||||||
|
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cursor pagination
|
||||||
|
if options.CursorForward != "" || options.CursorBackward != "" {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// expandJoins is empty for resolvemcp — no custom SQL join support yet
|
||||||
|
cursorFilter, err := getCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("cursor error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cursorFilter != "" {
|
||||||
|
sanitized := common.SanitizeWhereClause(cursorFilter, reflection.ExtractTableNameOnly(tableName), &options)
|
||||||
|
sanitized = common.EnsureOuterParentheses(sanitized)
|
||||||
|
if sanitized != "" {
|
||||||
|
query = query.Where(sanitized)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count — must happen before preloads are applied; Bun panics when counting with relations.
|
||||||
|
total, err := query.Count(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error counting records: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pagination
|
||||||
|
if options.Limit != nil && *options.Limit > 0 {
|
||||||
|
query = query.Limit(*options.Limit)
|
||||||
|
}
|
||||||
|
if options.Offset != nil && *options.Offset > 0 {
|
||||||
|
query = query.Offset(*options.Offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Preloads — applied after count to avoid Bun panic when counting with relations.
|
||||||
|
if len(options.Preload) > 0 {
|
||||||
|
var preloadErr error
|
||||||
|
query, preloadErr = h.applyPreloads(model, query, options.Preload)
|
||||||
|
if preloadErr != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to apply preloads: %w", preloadErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeforeRead hook
|
||||||
|
hookCtx.Query = query
|
||||||
|
if err := h.hooks.Execute(BeforeRead, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var data interface{}
|
||||||
|
if id != "" {
|
||||||
|
singleResult := reflect.New(modelType).Interface()
|
||||||
|
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||||
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return nil, nil, fmt.Errorf("record not found")
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||||
|
}
|
||||||
|
data = singleResult
|
||||||
|
} else {
|
||||||
|
if err := query.Scan(ctx, modelPtr); err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("query error: %w", err)
|
||||||
|
}
|
||||||
|
data = reflect.ValueOf(modelPtr).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 0
|
||||||
|
offset := 0
|
||||||
|
if options.Limit != nil {
|
||||||
|
limit = *options.Limit
|
||||||
|
}
|
||||||
|
if options.Offset != nil {
|
||||||
|
offset = *options.Offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count is the number of records in this page, not the total.
|
||||||
|
var pageCount int64
|
||||||
|
if id != "" {
|
||||||
|
pageCount = 1
|
||||||
|
} else {
|
||||||
|
pageCount = int64(reflect.ValueOf(data).Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := &common.Metadata{
|
||||||
|
Total: int64(total),
|
||||||
|
Filtered: int64(total),
|
||||||
|
Count: pageCount,
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
// AfterRead hook
|
||||||
|
hookCtx.Result = data
|
||||||
|
if err := h.hooks.Execute(AfterRead, hookCtx); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, metadata, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeCreate inserts one or more records.
|
||||||
|
func (h *Handler) executeCreate(ctx context.Context, schema, entity string, data interface{}) (_ interface{}, retErr error) {
|
||||||
|
defer recoverPanic(&retErr)
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "create",
|
||||||
|
Data: data,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeCreate, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use potentially modified data
|
||||||
|
data = hookCtx.Data
|
||||||
|
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
query := h.db.NewInsert().Table(tableName)
|
||||||
|
for key, value := range v {
|
||||||
|
query = query.Value(key, value)
|
||||||
|
}
|
||||||
|
if _, err := query.Exec(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("create error: %w", err)
|
||||||
|
}
|
||||||
|
hookCtx.Result = v
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
|
||||||
|
case []interface{}:
|
||||||
|
results := make([]interface{}, 0, len(v))
|
||||||
|
err := h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
for _, item := range v {
|
||||||
|
itemMap, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("each item must be an object")
|
||||||
|
}
|
||||||
|
q := tx.NewInsert().Table(tableName)
|
||||||
|
for key, value := range itemMap {
|
||||||
|
q = q.Value(key, value)
|
||||||
|
}
|
||||||
|
if _, err := q.Exec(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
results = append(results, item)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("batch create error: %w", err)
|
||||||
|
}
|
||||||
|
hookCtx.Result = results
|
||||||
|
if err := h.hooks.Execute(AfterCreate, hookCtx); err != nil {
|
||||||
|
return nil, fmt.Errorf("AfterCreate hook failed: %w", err)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("data must be an object or array of objects")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeUpdate updates a record by ID.
|
||||||
|
func (h *Handler) executeUpdate(ctx context.Context, schema, entity, id string, data interface{}) (_ interface{}, retErr error) {
|
||||||
|
defer recoverPanic(&retErr)
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
updates, ok := data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("data must be an object")
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == "" {
|
||||||
|
if idVal, exists := updates["id"]; exists {
|
||||||
|
id = fmt.Sprintf("%v", idVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("update requires an ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
var updateResult interface{}
|
||||||
|
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
// Read existing record
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
existingRecord := reflect.New(modelType).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("no records found to update")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error fetching existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to map
|
||||||
|
existingMap := make(map[string]interface{})
|
||||||
|
jsonData, err := json.Marshal(existingRecord)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error marshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &existingMap); err != nil {
|
||||||
|
return fmt.Errorf("error unmarshaling existing record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "update",
|
||||||
|
ID: id,
|
||||||
|
Data: updates,
|
||||||
|
Tx: tx,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeUpdate, hookCtx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if modifiedData, ok := hookCtx.Data.(map[string]interface{}); ok {
|
||||||
|
updates = modifiedData
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merge non-nil, non-empty values
|
||||||
|
for key, newValue := range updates {
|
||||||
|
if newValue == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strVal, ok := newValue.(string); ok && strVal == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[key] = newValue
|
||||||
|
}
|
||||||
|
|
||||||
|
q := tx.NewUpdate().Table(tableName).SetMap(existingMap).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
res, err := q.Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error updating record: %w", err)
|
||||||
|
}
|
||||||
|
if res.RowsAffected() == 0 {
|
||||||
|
return fmt.Errorf("no records found to update")
|
||||||
|
}
|
||||||
|
|
||||||
|
updateResult = existingMap
|
||||||
|
hookCtx.Result = updateResult
|
||||||
|
return h.hooks.Execute(AfterUpdate, hookCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return updateResult, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// executeDelete deletes a record by ID.
|
||||||
|
func (h *Handler) executeDelete(ctx context.Context, schema, entity, id string) (_ interface{}, retErr error) {
|
||||||
|
defer recoverPanic(&retErr)
|
||||||
|
if id == "" {
|
||||||
|
return nil, fmt.Errorf("delete requires an ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := h.registry.GetModelByEntity(schema, entity)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("model not found: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := common.ValidateAndUnwrapModel(model)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
model = result.Model
|
||||||
|
tableName := h.getTableName(schema, entity, model)
|
||||||
|
ctx = withRequestData(ctx, schema, entity, tableName, model, result.ModelPtr)
|
||||||
|
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Operation: "delete",
|
||||||
|
ID: id,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
var recordToDelete interface{}
|
||||||
|
|
||||||
|
err = h.db.RunInTransaction(ctx, func(tx common.Database) error {
|
||||||
|
record := reflect.New(modelType).Interface()
|
||||||
|
selectQuery := tx.NewSelect().Model(record).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return fmt.Errorf("record not found")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("error fetching record: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := tx.NewDelete().Table(tableName).
|
||||||
|
Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id).
|
||||||
|
Exec(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("delete error: %w", err)
|
||||||
|
}
|
||||||
|
if res.RowsAffected() == 0 {
|
||||||
|
return fmt.Errorf("record not found or already deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
recordToDelete = record
|
||||||
|
hookCtx.Tx = tx
|
||||||
|
hookCtx.Result = record
|
||||||
|
return h.hooks.Execute(AfterDelete, hookCtx)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("[resolvemcp] Deleted record %s from %s.%s", id, schema, entity)
|
||||||
|
return recordToDelete, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyFilters applies all filters with OR grouping logic.
|
||||||
|
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(filters) {
|
||||||
|
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||||
|
|
||||||
|
if startORGroup {
|
||||||
|
orGroup := []common.FilterOption{filters[i]}
|
||||||
|
j := i + 1
|
||||||
|
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||||
|
orGroup = append(orGroup, filters[j])
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
query = h.applyFilterGroup(query, orGroup)
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
condition, args := h.buildFilterCondition(filters[i])
|
||||||
|
if condition != "" {
|
||||||
|
query = query.Where(condition, args...)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
var conditions []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
for _, filter := range filters {
|
||||||
|
condition, filterArgs := h.buildFilterCondition(filter)
|
||||||
|
if condition != "" {
|
||||||
|
conditions = append(conditions, condition)
|
||||||
|
args = append(args, filterArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
return query.Where(conditions[0], args...)
|
||||||
|
}
|
||||||
|
return query.Where("("+strings.Join(conditions, " OR ")+")", args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (condition string, args []interface{}) {
|
||||||
|
switch filter.Operator {
|
||||||
|
case "eq", "=":
|
||||||
|
return fmt.Sprintf("%s = ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "neq", "!=", "<>":
|
||||||
|
return fmt.Sprintf("%s != ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "gt", ">":
|
||||||
|
return fmt.Sprintf("%s > ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "gte", ">=":
|
||||||
|
return fmt.Sprintf("%s >= ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "lt", "<":
|
||||||
|
return fmt.Sprintf("%s < ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "lte", "<=":
|
||||||
|
return fmt.Sprintf("%s <= ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "like":
|
||||||
|
return fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "ilike":
|
||||||
|
return fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column), []interface{}{filter.Value}
|
||||||
|
case "in":
|
||||||
|
condition, args := common.BuildInCondition(filter.Column, filter.Value)
|
||||||
|
return condition, args
|
||||||
|
case "is_null":
|
||||||
|
return fmt.Sprintf("%s IS NULL", filter.Column), nil
|
||||||
|
case "is_not_null":
|
||||||
|
return fmt.Sprintf("%s IS NOT NULL", filter.Column), nil
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyPreloads(model interface{}, query common.SelectQuery, preloads []common.PreloadOption) (common.SelectQuery, error) {
|
||||||
|
for i := range preloads {
|
||||||
|
preload := &preloads[i]
|
||||||
|
if preload.Relation == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
query = query.PreloadRelation(preload.Relation)
|
||||||
|
}
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,113 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookType defines the type of hook to execute
|
||||||
|
type HookType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
|
BeforeRead HookType = "before_read"
|
||||||
|
AfterRead HookType = "after_read"
|
||||||
|
|
||||||
|
BeforeCreate HookType = "before_create"
|
||||||
|
AfterCreate HookType = "after_create"
|
||||||
|
|
||||||
|
BeforeUpdate HookType = "before_update"
|
||||||
|
AfterUpdate HookType = "after_update"
|
||||||
|
|
||||||
|
BeforeDelete HookType = "before_delete"
|
||||||
|
AfterDelete HookType = "after_delete"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HookContext contains all the data available to a hook
|
||||||
|
type HookContext struct {
|
||||||
|
Context context.Context
|
||||||
|
Handler *Handler
|
||||||
|
Schema string
|
||||||
|
Entity string
|
||||||
|
Model interface{}
|
||||||
|
Options common.RequestOptions
|
||||||
|
Operation string
|
||||||
|
ID string
|
||||||
|
Data interface{}
|
||||||
|
Result interface{}
|
||||||
|
Error error
|
||||||
|
Query common.SelectQuery
|
||||||
|
Abort bool
|
||||||
|
AbortMessage string
|
||||||
|
AbortCode int
|
||||||
|
Tx common.Database
|
||||||
|
}
|
||||||
|
|
||||||
|
// HookFunc is the signature for hook functions
|
||||||
|
type HookFunc func(*HookContext) error
|
||||||
|
|
||||||
|
// HookRegistry manages all registered hooks
|
||||||
|
type HookRegistry struct {
|
||||||
|
hooks map[HookType][]HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHookRegistry() *HookRegistry {
|
||||||
|
return &HookRegistry{
|
||||||
|
hooks: make(map[HookType][]HookFunc),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Register(hookType HookType, hook HookFunc) {
|
||||||
|
if r.hooks == nil {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
r.hooks[hookType] = append(r.hooks[hookType], hook)
|
||||||
|
logger.Info("Registered resolvemcp hook for %s (total: %d)", hookType, len(r.hooks[hookType]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) RegisterMultiple(hookTypes []HookType, hook HookFunc) {
|
||||||
|
for _, hookType := range hookTypes {
|
||||||
|
r.Register(hookType, hook)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
if !exists || len(hooks) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("Executing %d resolvemcp hook(s) for %s", len(hooks), hookType)
|
||||||
|
|
||||||
|
for i, hook := range hooks {
|
||||||
|
if err := hook(ctx); err != nil {
|
||||||
|
logger.Error("resolvemcp hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("resolvemcp hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) Clear(hookType HookType) {
|
||||||
|
delete(r.hooks, hookType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) ClearAll() {
|
||||||
|
r.hooks = make(map[HookType][]HookFunc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *HookRegistry) HasHooks(hookType HookType) bool {
|
||||||
|
hooks, exists := r.hooks[hookType]
|
||||||
|
return exists && len(hooks) > 0
|
||||||
|
}
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// OAuth2 registration on the Handler
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// oauth2Registration stores a configured auth provider and its route config.
|
||||||
|
type oauth2Registration struct {
|
||||||
|
auth *security.DatabaseAuthenticator
|
||||||
|
cfg OAuth2RouteConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterOAuth2 attaches an OAuth2 provider to the Handler.
|
||||||
|
// The login and callback HTTP routes are served by HTTPHandler / StreamableHTTPMux.
|
||||||
|
// Call this once per provider before serving requests.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// auth := security.NewGoogleAuthenticator(clientID, secret, redirectURL, db)
|
||||||
|
// handler.RegisterOAuth2(auth, resolvemcp.OAuth2RouteConfig{
|
||||||
|
// ProviderName: "google",
|
||||||
|
// LoginPath: "/auth/google/login",
|
||||||
|
// CallbackPath: "/auth/google/callback",
|
||||||
|
// AfterLoginRedirect: "/",
|
||||||
|
// })
|
||||||
|
func (h *Handler) RegisterOAuth2(auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||||
|
h.oauth2Regs = append(h.oauth2Regs, oauth2Registration{auth: auth, cfg: cfg})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPHandler returns a single http.Handler that serves:
|
||||||
|
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||||
|
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||||
|
// - The MCP SSE transport wrapped with required authentication middleware
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// auth := security.NewGoogleAuthenticator(...)
|
||||||
|
// handler.RegisterOAuth2(auth, cfg)
|
||||||
|
// handler.EnableOAuthServer(security.OAuthServerConfig{Issuer: "https://api.example.com"})
|
||||||
|
// security.RegisterSecurityHooks(handler, securityList)
|
||||||
|
// http.ListenAndServe(":8080", handler.HTTPHandler(securityList))
|
||||||
|
func (h *Handler) HTTPHandler(securityList *security.SecurityList) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
if h.oauthSrv != nil {
|
||||||
|
h.mountOAuthServerRoutes(mux)
|
||||||
|
}
|
||||||
|
h.mountOAuth2Routes(mux)
|
||||||
|
|
||||||
|
mcpHandler := h.AuthedSSEServer(securityList)
|
||||||
|
basePath := h.config.BasePath
|
||||||
|
if basePath == "" {
|
||||||
|
basePath = "/mcp"
|
||||||
|
}
|
||||||
|
mux.Handle(basePath+"/sse", mcpHandler)
|
||||||
|
mux.Handle(basePath+"/message", mcpHandler)
|
||||||
|
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamableHTTPMux returns a single http.Handler that serves:
|
||||||
|
// - MCP OAuth2 authorization server endpoints (when EnableOAuthServer has been called)
|
||||||
|
// - OAuth2 login + callback routes for every registered provider (legacy cookie flow)
|
||||||
|
// - The MCP streamable HTTP transport wrapped with required authentication middleware
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// http.ListenAndServe(":8080", handler.StreamableHTTPMux(securityList))
|
||||||
|
func (h *Handler) StreamableHTTPMux(securityList *security.SecurityList) http.Handler {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
if h.oauthSrv != nil {
|
||||||
|
h.mountOAuthServerRoutes(mux)
|
||||||
|
}
|
||||||
|
h.mountOAuth2Routes(mux)
|
||||||
|
|
||||||
|
mcpHandler := h.AuthedStreamableHTTPServer(securityList)
|
||||||
|
basePath := h.config.BasePath
|
||||||
|
if basePath == "" {
|
||||||
|
basePath = "/mcp"
|
||||||
|
}
|
||||||
|
mux.Handle(basePath+"/", http.StripPrefix(basePath, mcpHandler))
|
||||||
|
mux.Handle(basePath, mcpHandler)
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
// mountOAuth2Routes registers all stored OAuth2 login+callback routes onto mux.
|
||||||
|
func (h *Handler) mountOAuth2Routes(mux *http.ServeMux) {
|
||||||
|
for _, reg := range h.oauth2Regs {
|
||||||
|
var cookieOpts []security.SessionCookieOptions
|
||||||
|
if reg.cfg.CookieOptions != nil {
|
||||||
|
cookieOpts = append(cookieOpts, *reg.cfg.CookieOptions)
|
||||||
|
}
|
||||||
|
mux.Handle(reg.cfg.LoginPath, OAuth2LoginHandler(reg.auth, reg.cfg.ProviderName))
|
||||||
|
mux.Handle(reg.cfg.CallbackPath, OAuth2CallbackHandler(reg.auth, reg.cfg.ProviderName, reg.cfg.AfterLoginRedirect, cookieOpts...))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Auth-wrapped transports
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// AuthedSSEServer wraps SSEServer with required authentication middleware from pkg/security.
|
||||||
|
// The middleware reads the session cookie / Authorization header and populates the user
|
||||||
|
// context into the request context, making it available to BeforeHandle security hooks.
|
||||||
|
// Unauthenticated requests receive 401 before reaching any MCP tool.
|
||||||
|
func (h *Handler) AuthedSSEServer(securityList *security.SecurityList) http.Handler {
|
||||||
|
return security.NewAuthMiddleware(securityList)(h.SSEServer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionalAuthSSEServer wraps SSEServer with optional authentication middleware.
|
||||||
|
// Unauthenticated requests continue as guest rather than returning 401.
|
||||||
|
// Use together with RegisterSecurityHooks and per-model CanPublicRead/Write rules
|
||||||
|
// to allow mixed public/private access.
|
||||||
|
func (h *Handler) OptionalAuthSSEServer(securityList *security.SecurityList) http.Handler {
|
||||||
|
return security.NewOptionalAuthMiddleware(securityList)(h.SSEServer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthedStreamableHTTPServer wraps StreamableHTTPServer with required authentication middleware.
|
||||||
|
func (h *Handler) AuthedStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||||
|
return security.NewAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionalAuthStreamableHTTPServer wraps StreamableHTTPServer with optional authentication middleware.
|
||||||
|
func (h *Handler) OptionalAuthStreamableHTTPServer(securityList *security.SecurityList) http.Handler {
|
||||||
|
return security.NewOptionalAuthMiddleware(securityList)(h.StreamableHTTPServer())
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// OAuth2 route config and standalone handlers
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// OAuth2RouteConfig configures the OAuth2 HTTP endpoints for a single provider.
|
||||||
|
type OAuth2RouteConfig struct {
|
||||||
|
// ProviderName is the OAuth2 provider name as registered with WithOAuth2()
|
||||||
|
// (e.g. "google", "github", "microsoft").
|
||||||
|
ProviderName string
|
||||||
|
|
||||||
|
// LoginPath is the HTTP path that redirects the browser to the OAuth2 provider
|
||||||
|
// (e.g. "/auth/google/login").
|
||||||
|
LoginPath string
|
||||||
|
|
||||||
|
// CallbackPath is the HTTP path that the OAuth2 provider redirects back to
|
||||||
|
// (e.g. "/auth/google/callback"). Must match the RedirectURL in OAuth2Config.
|
||||||
|
CallbackPath string
|
||||||
|
|
||||||
|
// AfterLoginRedirect is the URL to redirect the browser to after a successful
|
||||||
|
// login. When empty the LoginResponse JSON is written directly to the response.
|
||||||
|
AfterLoginRedirect string
|
||||||
|
|
||||||
|
// CookieOptions customises the session cookie written on successful login.
|
||||||
|
// Defaults to HttpOnly, Secure, SameSite=Lax when nil.
|
||||||
|
CookieOptions *security.SessionCookieOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2LoginHandler returns an http.HandlerFunc that redirects the browser to
|
||||||
|
// the OAuth2 provider's authorization URL.
|
||||||
|
//
|
||||||
|
// Register it on any router:
|
||||||
|
//
|
||||||
|
// mux.Handle("/auth/google/login", resolvemcp.OAuth2LoginHandler(auth, "google"))
|
||||||
|
func OAuth2LoginHandler(auth *security.DatabaseAuthenticator, providerName string) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, err := auth.OAuth2GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "failed to generate state", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authURL, err := auth.OAuth2GetAuthURL(providerName, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, authURL, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2CallbackHandler returns an http.HandlerFunc that handles the OAuth2 provider
|
||||||
|
// callback: exchanges the authorization code for a session token, writes the session
|
||||||
|
// cookie, then either redirects to afterLoginRedirect or writes the LoginResponse as JSON.
|
||||||
|
//
|
||||||
|
// Register it on any router:
|
||||||
|
//
|
||||||
|
// mux.Handle("/auth/google/callback", resolvemcp.OAuth2CallbackHandler(auth, "google", "/dashboard"))
|
||||||
|
func OAuth2CallbackHandler(auth *security.DatabaseAuthenticator, providerName, afterLoginRedirect string, cookieOpts ...security.SessionCookieOptions) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
state := r.URL.Query().Get("state")
|
||||||
|
if code == "" {
|
||||||
|
http.Error(w, "missing code parameter", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loginResp, err := auth.OAuth2HandleCallback(r.Context(), providerName, code, state)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
security.SetSessionCookie(w, loginResp, cookieOpts...)
|
||||||
|
|
||||||
|
if afterLoginRedirect != "" {
|
||||||
|
http.Redirect(w, r, afterLoginRedirect, http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(loginResp) //nolint:errcheck
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Gorilla Mux convenience helpers
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// SetupMuxOAuth2Routes registers the OAuth2 login and callback routes on a Gorilla Mux router.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resolvemcp.SetupMuxOAuth2Routes(r, auth, resolvemcp.OAuth2RouteConfig{
|
||||||
|
// ProviderName: "google", LoginPath: "/auth/google/login",
|
||||||
|
// CallbackPath: "/auth/google/callback", AfterLoginRedirect: "/",
|
||||||
|
// })
|
||||||
|
func SetupMuxOAuth2Routes(muxRouter *mux.Router, auth *security.DatabaseAuthenticator, cfg OAuth2RouteConfig) {
|
||||||
|
var cookieOpts []security.SessionCookieOptions
|
||||||
|
if cfg.CookieOptions != nil {
|
||||||
|
cookieOpts = append(cookieOpts, *cfg.CookieOptions)
|
||||||
|
}
|
||||||
|
|
||||||
|
muxRouter.Handle(cfg.LoginPath,
|
||||||
|
OAuth2LoginHandler(auth, cfg.ProviderName),
|
||||||
|
).Methods(http.MethodGet)
|
||||||
|
|
||||||
|
muxRouter.Handle(cfg.CallbackPath,
|
||||||
|
OAuth2CallbackHandler(auth, cfg.ProviderName, cfg.AfterLoginRedirect, cookieOpts...),
|
||||||
|
).Methods(http.MethodGet)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMuxRoutesWithAuth mounts the MCP SSE endpoints on a Gorilla Mux router
|
||||||
|
// with required authentication middleware applied.
|
||||||
|
func SetupMuxRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.AuthedSSEServer(securityList)
|
||||||
|
|
||||||
|
muxRouter.Handle(basePath+"/sse", h).Methods(http.MethodGet, http.MethodOptions)
|
||||||
|
muxRouter.Handle(basePath+"/message", h).Methods(http.MethodPost, http.MethodOptions)
|
||||||
|
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMuxStreamableHTTPRoutesWithAuth mounts the MCP streamable HTTP endpoint on a
|
||||||
|
// Gorilla Mux router with required authentication middleware applied.
|
||||||
|
func SetupMuxStreamableHTTPRoutesWithAuth(muxRouter *mux.Router, handler *Handler, securityList *security.SecurityList) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.AuthedStreamableHTTPServer(securityList)
|
||||||
|
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||||
|
}
|
||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableOAuthServer activates the MCP-standard OAuth2 authorization server on this Handler.
|
||||||
|
//
|
||||||
|
// Pass a DatabaseAuthenticator to enable direct username/password login — the server acts as
|
||||||
|
// its own identity provider and renders a login form at /oauth/authorize. Pass nil to use
|
||||||
|
// only external providers registered via RegisterOAuth2Provider.
|
||||||
|
//
|
||||||
|
// After calling this, HTTPHandler and StreamableHTTPMux serve the full set of RFC-compliant
|
||||||
|
// endpoints required by MCP clients alongside the MCP transport:
|
||||||
|
//
|
||||||
|
// GET /.well-known/oauth-authorization-server RFC 8414 — auto-discovery
|
||||||
|
// POST /oauth/register RFC 7591 — dynamic client registration
|
||||||
|
// GET /oauth/authorize OAuth 2.1 + PKCE — start login
|
||||||
|
// POST /oauth/authorize Login form submission (password flow)
|
||||||
|
// POST /oauth/token Bearer token exchange + refresh
|
||||||
|
// GET /oauth/provider/callback External provider redirect target
|
||||||
|
func (h *Handler) EnableOAuthServer(cfg security.OAuthServerConfig, auth *security.DatabaseAuthenticator) {
|
||||||
|
h.oauthSrv = security.NewOAuthServer(cfg, auth)
|
||||||
|
// Wire any external providers already registered via RegisterOAuth2
|
||||||
|
for _, reg := range h.oauth2Regs {
|
||||||
|
h.oauthSrv.RegisterExternalProvider(reg.auth, reg.cfg.ProviderName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterOAuth2Provider adds an external OAuth2 provider to the MCP OAuth2 authorization server.
|
||||||
|
// EnableOAuthServer must be called before this. The auth must have been configured with
|
||||||
|
// WithOAuth2(providerName, ...) for the given provider name.
|
||||||
|
func (h *Handler) RegisterOAuth2Provider(auth *security.DatabaseAuthenticator, providerName string) {
|
||||||
|
if h.oauthSrv != nil {
|
||||||
|
h.oauthSrv.RegisterExternalProvider(auth, providerName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mountOAuthServerRoutes mounts the security.OAuthServer's HTTP handler onto mux.
|
||||||
|
func (h *Handler) mountOAuthServerRoutes(mux *http.ServeMux) {
|
||||||
|
oauthHandler := h.oauthSrv.HTTPHandler()
|
||||||
|
// Delegate all /oauth/ and /.well-known/ paths to the OAuth server
|
||||||
|
mux.Handle("/.well-known/", oauthHandler)
|
||||||
|
mux.Handle("/oauth/", oauthHandler)
|
||||||
|
if h.oauthSrv != nil {
|
||||||
|
// Also mount the external provider callback path if it differs from /oauth/
|
||||||
|
mux.Handle(h.oauthSrv.ProviderCallbackPath(), oauthHandler)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
// Package resolvemcp exposes registered database models as Model Context Protocol (MCP) tools
|
||||||
|
// and resources over HTTP/SSE transport.
|
||||||
|
//
|
||||||
|
// It mirrors the resolvespec package patterns:
|
||||||
|
// - Same model registration API
|
||||||
|
// - Same filter, sort, cursor pagination, preload options
|
||||||
|
// - Same lifecycle hook system
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// handler := resolvemcp.NewHandlerWithGORM(db, resolvemcp.Config{BaseURL: "http://localhost:8080"})
|
||||||
|
// handler.RegisterModel("public", "users", &User{})
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// resolvemcp.SetupMuxRoutes(r, handler)
|
||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/uptrace/bun"
|
||||||
|
bunrouter "github.com/uptrace/bunrouter"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common/adapters/database"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds configuration for the resolvemcp handler.
|
||||||
|
type Config struct {
|
||||||
|
// BaseURL is the public-facing base URL of the server (e.g. "http://localhost:8080").
|
||||||
|
// It is sent to MCP clients during the SSE handshake so they know where to POST messages.
|
||||||
|
BaseURL string
|
||||||
|
|
||||||
|
// BasePath is the URL path prefix where the MCP endpoints are mounted (e.g. "/mcp").
|
||||||
|
// If empty, the path is detected from each incoming request automatically.
|
||||||
|
BasePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerWithGORM creates a Handler backed by a GORM database connection.
|
||||||
|
func NewHandlerWithGORM(db *gorm.DB, cfg Config) *Handler {
|
||||||
|
return NewHandler(database.NewGormAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerWithBun creates a Handler backed by a Bun database connection.
|
||||||
|
func NewHandlerWithBun(db *bun.DB, cfg Config) *Handler {
|
||||||
|
return NewHandler(database.NewBunAdapter(db), modelregistry.NewModelRegistry(), cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandlerWithDB creates a Handler using an existing common.Database and a new registry.
|
||||||
|
func NewHandlerWithDB(db common.Database, cfg Config) *Handler {
|
||||||
|
return NewHandler(db, modelregistry.NewModelRegistry(), cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMuxRoutes mounts the MCP HTTP/SSE endpoints on the given Gorilla Mux router
|
||||||
|
// using the base path from Config.BasePath (falls back to "/mcp" if empty).
|
||||||
|
//
|
||||||
|
// Two routes are registered:
|
||||||
|
// - GET {basePath}/sse — SSE connection endpoint (client subscribes here)
|
||||||
|
// - POST {basePath}/message — JSON-RPC message endpoint (client sends requests here)
|
||||||
|
//
|
||||||
|
// To protect these routes with authentication, wrap the mux router or apply middleware
|
||||||
|
// before calling SetupMuxRoutes.
|
||||||
|
func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.SSEServer()
|
||||||
|
|
||||||
|
muxRouter.Handle(basePath+"/sse", h).Methods("GET", "OPTIONS")
|
||||||
|
muxRouter.Handle(basePath+"/message", h).Methods("POST", "OPTIONS")
|
||||||
|
|
||||||
|
// Convenience: also expose the full SSE server at basePath for clients that
|
||||||
|
// use ServeHTTP directly (e.g. net/http default mux).
|
||||||
|
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupBunRouterRoutes mounts the MCP HTTP/SSE endpoints on a bunrouter router
|
||||||
|
// using the base path from Config.BasePath.
|
||||||
|
//
|
||||||
|
// Two routes are registered:
|
||||||
|
// - GET {basePath}/sse — SSE connection endpoint
|
||||||
|
// - POST {basePath}/message — JSON-RPC message endpoint
|
||||||
|
func SetupBunRouterRoutes(router *bunrouter.Router, handler *Handler) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.SSEServer()
|
||||||
|
|
||||||
|
router.GET(basePath+"/sse", bunrouter.HTTPHandler(h))
|
||||||
|
router.POST(basePath+"/message", bunrouter.HTTPHandler(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSSEServer returns an http.Handler that serves MCP over SSE.
|
||||||
|
// If Config.BasePath is set it is used directly; otherwise the base path is
|
||||||
|
// detected from each incoming request (by stripping the "/sse" or "/message" suffix).
|
||||||
|
//
|
||||||
|
// h := resolvemcp.NewSSEServer(handler)
|
||||||
|
// http.Handle("/api/mcp/", h)
|
||||||
|
func NewSSEServer(handler *Handler) http.Handler {
|
||||||
|
return handler.SSEServer()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupMuxStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on the given Gorilla Mux router.
|
||||||
|
// The streamable HTTP transport uses a single endpoint (Config.BasePath) for all communication:
|
||||||
|
// POST for client→server messages, GET for server→client streaming.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// resolvemcp.SetupMuxStreamableHTTPRoutes(r, handler) // mounts at Config.BasePath
|
||||||
|
func SetupMuxStreamableHTTPRoutes(muxRouter *mux.Router, handler *Handler) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.StreamableHTTPServer()
|
||||||
|
muxRouter.PathPrefix(basePath).Handler(http.StripPrefix(basePath, h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupBunRouterStreamableHTTPRoutes mounts the MCP streamable HTTP endpoint on a bunrouter router.
|
||||||
|
// The streamable HTTP transport uses a single endpoint (Config.BasePath).
|
||||||
|
func SetupBunRouterStreamableHTTPRoutes(router *bunrouter.Router, handler *Handler) {
|
||||||
|
basePath := handler.config.BasePath
|
||||||
|
h := handler.StreamableHTTPServer()
|
||||||
|
router.GET(basePath, bunrouter.HTTPHandler(h))
|
||||||
|
router.POST(basePath, bunrouter.HTTPHandler(h))
|
||||||
|
router.DELETE(basePath, bunrouter.HTTPHandler(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStreamableHTTPHandler returns an http.Handler that serves MCP over the streamable HTTP transport.
|
||||||
|
// Mount it at the desired path; that path becomes the MCP endpoint.
|
||||||
|
//
|
||||||
|
// h := resolvemcp.NewStreamableHTTPHandler(handler)
|
||||||
|
// http.Handle("/mcp", h)
|
||||||
|
// engine.Any("/mcp", gin.WrapH(h))
|
||||||
|
func NewStreamableHTTPHandler(handler *Handler) http.Handler {
|
||||||
|
return handler.StreamableHTTPServer()
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterSecurityHooks wires the security package's access-control layer into the
|
||||||
|
// resolvemcp handler. Call it once after creating the handler, before registering models.
|
||||||
|
//
|
||||||
|
// The following controls are applied:
|
||||||
|
// - Per-entity operation rules (CanRead, CanCreate, CanUpdate, CanDelete, CanPublic*)
|
||||||
|
// stored via RegisterModelWithRules / SetModelRules.
|
||||||
|
// - Row-level security: WHERE clause injected per user from the SecurityList provider.
|
||||||
|
// - Column-level security: sensitive columns masked/hidden in read results.
|
||||||
|
// - Audit logging after each read.
|
||||||
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// BeforeHandle: enforce model-level operation rules (auth check).
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeRead (1st): load RLS + CLS rules from the provider into SecurityList.
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.LoadSecurityRules(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeRead (2nd): apply row-level security — injects a WHERE clause into the query.
|
||||||
|
// resolvemcp has no separate BeforeScan hook; the query is available in BeforeRead.
|
||||||
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.ApplyRowSecurity(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// AfterRead (1st): apply column-level security — mask/hide columns in the result.
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.ApplyColumnSecurity(newSecurityContext(hookCtx), securityList)
|
||||||
|
})
|
||||||
|
|
||||||
|
// AfterRead (2nd): audit log.
|
||||||
|
handler.Hooks().Register(AfterRead, func(hookCtx *HookContext) error {
|
||||||
|
return security.LogDataAccess(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeUpdate: enforce CanUpdate rule.
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
return security.CheckModelUpdateAllowed(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
// BeforeDelete: enforce CanDelete rule.
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
return security.CheckModelDeleteAllowed(newSecurityContext(hookCtx))
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.Info("Security hooks registered for resolvemcp handler")
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// securityContext — adapts resolvemcp.HookContext to security.SecurityContext
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type securityContext struct {
|
||||||
|
ctx *HookContext
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSecurityContext(ctx *HookContext) security.SecurityContext {
|
||||||
|
return &securityContext{ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetContext() context.Context {
|
||||||
|
return s.ctx.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetUserID() (int, bool) {
|
||||||
|
return security.GetUserID(s.ctx.Context)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetSchema() string {
|
||||||
|
return s.ctx.Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetEntity() string {
|
||||||
|
return s.ctx.Entity
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetModel() interface{} {
|
||||||
|
return s.ctx.Model
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetQuery() interface{} {
|
||||||
|
return s.ctx.Query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetQuery(query interface{}) {
|
||||||
|
if q, ok := query.(common.SelectQuery); ok {
|
||||||
|
s.ctx.Query = q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) GetResult() interface{} {
|
||||||
|
return s.ctx.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *securityContext) SetResult(result interface{}) {
|
||||||
|
s.ctx.Result = result
|
||||||
|
}
|
||||||
@@ -0,0 +1,692 @@
|
|||||||
|
package resolvemcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/reflection"
|
||||||
|
)
|
||||||
|
|
||||||
|
// toolName builds the MCP tool name for a given operation and model.
|
||||||
|
func toolName(operation, schema, entity string) string {
|
||||||
|
if schema == "" {
|
||||||
|
return fmt.Sprintf("%s_%s", operation, entity)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s_%s_%s", operation, schema, entity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerModelTools registers the four CRUD tools and resource for a model.
|
||||||
|
func registerModelTools(h *Handler, schema, entity string, model interface{}) {
|
||||||
|
info := buildModelInfo(schema, entity, model)
|
||||||
|
registerReadTool(h, schema, entity, info)
|
||||||
|
registerCreateTool(h, schema, entity, info)
|
||||||
|
registerUpdateTool(h, schema, entity, info)
|
||||||
|
registerDeleteTool(h, schema, entity, info)
|
||||||
|
registerModelResource(h, schema, entity, info)
|
||||||
|
|
||||||
|
logger.Info("[resolvemcp] Registered MCP tools for %s", info.fullName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Model introspection
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// modelInfo holds pre-computed metadata for a model used in tool descriptions.
|
||||||
|
type modelInfo struct {
|
||||||
|
fullName string // e.g. "public.users"
|
||||||
|
pkName string // e.g. "id"
|
||||||
|
columns []columnInfo
|
||||||
|
relationNames []string
|
||||||
|
schemaDoc string // formatted multi-line schema listing
|
||||||
|
}
|
||||||
|
|
||||||
|
type columnInfo struct {
|
||||||
|
jsonName string
|
||||||
|
sqlName string
|
||||||
|
goType string
|
||||||
|
sqlType string
|
||||||
|
isPrimary bool
|
||||||
|
isUnique bool
|
||||||
|
isFK bool
|
||||||
|
nullable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildModelInfo extracts column metadata and pre-builds the schema documentation string.
|
||||||
|
func buildModelInfo(schema, entity string, model interface{}) modelInfo {
|
||||||
|
info := modelInfo{
|
||||||
|
fullName: buildModelName(schema, entity),
|
||||||
|
pkName: reflection.GetPrimaryKeyName(model),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap to base struct type
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
for modelType != nil && (modelType.Kind() == reflect.Ptr || modelType.Kind() == reflect.Slice) {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
details := reflection.GetModelColumnDetail(reflect.New(modelType).Elem())
|
||||||
|
|
||||||
|
for _, d := range details {
|
||||||
|
// Derive the JSON name from the struct field
|
||||||
|
jsonName := fieldJSONName(modelType, d.Name)
|
||||||
|
if jsonName == "" || jsonName == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip relation fields (slice or user-defined struct that isn't time.Time).
|
||||||
|
fieldType, found := modelType.FieldByName(d.Name)
|
||||||
|
if found {
|
||||||
|
ft := fieldType.Type
|
||||||
|
if ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
isUserStruct := ft.Kind() == reflect.Struct && ft.Name() != "Time" && ft.PkgPath() != ""
|
||||||
|
if ft.Kind() == reflect.Slice || isUserStruct {
|
||||||
|
info.relationNames = append(info.relationNames, jsonName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sqlName := d.SQLName
|
||||||
|
if sqlName == "" {
|
||||||
|
sqlName = jsonName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive Go type name, unwrapping pointer if needed.
|
||||||
|
goType := d.DataType
|
||||||
|
if goType == "" && found {
|
||||||
|
ft := fieldType.Type
|
||||||
|
for ft.Kind() == reflect.Ptr {
|
||||||
|
ft = ft.Elem()
|
||||||
|
}
|
||||||
|
goType = ft.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPrimary: use both the GORM-tag detection and a name comparison against
|
||||||
|
// the known primary key (handles camelCase "primaryKey" tags correctly).
|
||||||
|
isPrimary := d.SQLKey == "primary_key" ||
|
||||||
|
(info.pkName != "" && (sqlName == info.pkName || jsonName == info.pkName))
|
||||||
|
|
||||||
|
ci := columnInfo{
|
||||||
|
jsonName: jsonName,
|
||||||
|
sqlName: sqlName,
|
||||||
|
goType: goType,
|
||||||
|
sqlType: d.SQLDataType,
|
||||||
|
isPrimary: isPrimary,
|
||||||
|
isUnique: d.SQLKey == "unique" || d.SQLKey == "uniqueindex",
|
||||||
|
isFK: d.SQLKey == "foreign_key",
|
||||||
|
nullable: d.Nullable,
|
||||||
|
}
|
||||||
|
info.columns = append(info.columns, ci)
|
||||||
|
}
|
||||||
|
|
||||||
|
info.schemaDoc = buildSchemaDoc(info)
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldJSONName returns the JSON tag name for a struct field, falling back to the field name.
|
||||||
|
func fieldJSONName(modelType reflect.Type, fieldName string) string {
|
||||||
|
field, ok := modelType.FieldByName(fieldName)
|
||||||
|
if !ok {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
tag := field.Tag.Get("json")
|
||||||
|
if tag == "" {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(tag, ",", 2)
|
||||||
|
if parts[0] == "" {
|
||||||
|
return fieldName
|
||||||
|
}
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSchemaDoc builds a human-readable column listing for inclusion in tool descriptions.
|
||||||
|
func buildSchemaDoc(info modelInfo) string {
|
||||||
|
if len(info.columns) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("Columns:\n")
|
||||||
|
for _, c := range info.columns {
|
||||||
|
line := fmt.Sprintf(" • %s", c.jsonName)
|
||||||
|
|
||||||
|
typeDesc := c.goType
|
||||||
|
if c.sqlType != "" {
|
||||||
|
typeDesc = c.sqlType
|
||||||
|
}
|
||||||
|
if typeDesc != "" {
|
||||||
|
line += fmt.Sprintf(" (%s)", typeDesc)
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags []string
|
||||||
|
if c.isPrimary {
|
||||||
|
flags = append(flags, "primary key")
|
||||||
|
}
|
||||||
|
if c.isUnique {
|
||||||
|
flags = append(flags, "unique")
|
||||||
|
}
|
||||||
|
if c.isFK {
|
||||||
|
flags = append(flags, "foreign key")
|
||||||
|
}
|
||||||
|
if !c.nullable && !c.isPrimary {
|
||||||
|
flags = append(flags, "not null")
|
||||||
|
} else if c.nullable {
|
||||||
|
flags = append(flags, "nullable")
|
||||||
|
}
|
||||||
|
if len(flags) > 0 {
|
||||||
|
line += " — " + strings.Join(flags, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(line + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(info.relationNames) > 0 {
|
||||||
|
sb.WriteString("Relations (preloadable): " + strings.Join(info.relationNames, ", ") + "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// columnNameList returns a comma-separated list of JSON column names (for descriptions).
|
||||||
|
func columnNameList(cols []columnInfo) string {
|
||||||
|
names := make([]string, len(cols))
|
||||||
|
for i, c := range cols {
|
||||||
|
names[i] = c.jsonName
|
||||||
|
}
|
||||||
|
return strings.Join(names, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// writableColumnNames returns JSON names for all non-primary-key columns.
|
||||||
|
func writableColumnNames(cols []columnInfo) []string {
|
||||||
|
var names []string
|
||||||
|
for _, c := range cols {
|
||||||
|
if !c.isPrimary {
|
||||||
|
names = append(names, c.jsonName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Read tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerReadTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
|
name := toolName("read", schema, entity)
|
||||||
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Read records from the '%s' database table.", info.fullName))
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Primary key: '%s'. Pass it via 'id' to fetch a single record.", info.pkName))
|
||||||
|
}
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Pagination: use 'limit'/'offset' for offset-based paging, or 'cursor_forward'/'cursor_backward' (pass the primary key value of the last/first record on the current page) for cursor-based paging.",
|
||||||
|
"Filtering: each filter object requires 'column' (JSON field name) and 'operator'. Supported operators: = != > < >= <= like ilike in is_null is_not_null. Combine with 'logic_operator': AND (default) or OR.",
|
||||||
|
"Sorting: each sort object requires 'column' and 'direction' (asc or desc).",
|
||||||
|
)
|
||||||
|
if len(info.relationNames) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Preloadable relations: %s. Pass relation name in 'preloads'.", strings.Join(info.relationNames, ", ")))
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
filterDesc := `Array of filter objects. Example: [{"column":"status","operator":"=","value":"active"},{"column":"age","operator":">","value":18,"logic_operator":"AND"}]`
|
||||||
|
if len(info.columns) > 0 {
|
||||||
|
filterDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
sortDesc := `Array of sort objects. Example: [{"column":"created_at","direction":"desc"}]`
|
||||||
|
if len(info.columns) > 0 {
|
||||||
|
sortDesc += fmt.Sprintf(" Available columns: %s.", columnNameList(info.columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description(fmt.Sprintf("Primary key (%s) of a single record to fetch. Omit to return multiple records.", info.pkName)),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("limit",
|
||||||
|
mcp.Description("Maximum number of records to return per page. Recommended: 10–100."),
|
||||||
|
),
|
||||||
|
mcp.WithNumber("offset",
|
||||||
|
mcp.Description("Number of records to skip (for offset-based pagination). Use with 'limit'."),
|
||||||
|
),
|
||||||
|
mcp.WithString("cursor_forward",
|
||||||
|
mcp.Description(fmt.Sprintf("Cursor for the next page: pass the '%s' value of the last record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||||
|
),
|
||||||
|
mcp.WithString("cursor_backward",
|
||||||
|
mcp.Description(fmt.Sprintf("Cursor for the previous page: pass the '%s' value of the first record on the current page. Requires 'sort' to be set.", info.pkName)),
|
||||||
|
),
|
||||||
|
mcp.WithArray("columns",
|
||||||
|
mcp.Description(fmt.Sprintf("Columns to include in the result. Omit to return all columns. Available: %s.", columnNameList(info.columns))),
|
||||||
|
),
|
||||||
|
mcp.WithArray("omit_columns",
|
||||||
|
mcp.Description(fmt.Sprintf("Columns to exclude from the result. Available: %s.", columnNameList(info.columns))),
|
||||||
|
),
|
||||||
|
mcp.WithArray("filters",
|
||||||
|
mcp.Description(filterDesc),
|
||||||
|
),
|
||||||
|
mcp.WithArray("sort",
|
||||||
|
mcp.Description(sortDesc),
|
||||||
|
),
|
||||||
|
mcp.WithArray("preloads",
|
||||||
|
mcp.Description(buildPreloadDesc(info)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
options := parseRequestOptions(args)
|
||||||
|
|
||||||
|
data, metadata, err := h.executeRead(ctx, schema, entity, id, options)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": data,
|
||||||
|
"metadata": metadata,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildPreloadDesc(info modelInfo) string {
|
||||||
|
if len(info.relationNames) == 0 {
|
||||||
|
return `Array of relation preload objects. Each object: {"relation":"RelationName"}. No relations defined on this model.`
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(
|
||||||
|
`Array of relation preload objects. Each object: {"relation":"RelationName","columns":["col1","col2"]}. Available relations: %s.`,
|
||||||
|
strings.Join(info.relationNames, ", "),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Create tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerCreateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
|
name := toolName("create", schema, entity)
|
||||||
|
|
||||||
|
writable := writableColumnNames(info.columns)
|
||||||
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Create one or more new records in the '%s' table.", info.fullName))
|
||||||
|
if len(writable) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Writable fields: %s.", strings.Join(writable, ", ")))
|
||||||
|
}
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("The primary key ('%s') is typically auto-generated — omit it unless you need to supply it explicitly.", info.pkName))
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Pass a single JSON object to 'data' to create one record. Pass an array of objects to create multiple records in a single transaction (all succeed or all fail).",
|
||||||
|
)
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
dataDesc := "Record fields to create."
|
||||||
|
if len(writable) > 0 {
|
||||||
|
dataDesc += fmt.Sprintf(" Writable fields: %s.", strings.Join(writable, ", "))
|
||||||
|
}
|
||||||
|
dataDesc += " Pass a single object or an array of objects."
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithObject("data",
|
||||||
|
mcp.Description(dataDesc),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
data, ok := args["data"]
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.executeCreate(ctx, schema, entity, data)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Update tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerUpdateTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
|
name := toolName("update", schema, entity)
|
||||||
|
|
||||||
|
writable := writableColumnNames(info.columns)
|
||||||
|
|
||||||
|
var descParts []string
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Update an existing record in the '%s' table.", info.fullName))
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Identify the record by its primary key ('%s') via the 'id' argument or by including '%s' inside 'data'.", info.pkName, info.pkName))
|
||||||
|
}
|
||||||
|
if len(writable) > 0 {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Updatable fields: %s.", strings.Join(writable, ", ")))
|
||||||
|
}
|
||||||
|
descParts = append(descParts,
|
||||||
|
"Only non-null, non-empty fields in 'data' are applied — existing values are preserved for fields you omit. Returns the merged record as stored.",
|
||||||
|
)
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
descParts = append(descParts, info.schemaDoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
description := strings.Join(descParts, "\n\n")
|
||||||
|
|
||||||
|
idDesc := fmt.Sprintf("Primary key ('%s') of the record to update. Can also be included inside 'data'.", info.pkName)
|
||||||
|
|
||||||
|
dataDesc := "Fields to update (non-null, non-empty values are merged into the existing record)."
|
||||||
|
if len(writable) > 0 {
|
||||||
|
dataDesc += fmt.Sprintf(" Updatable fields: %s.", strings.Join(writable, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description(idDesc),
|
||||||
|
),
|
||||||
|
mcp.WithObject("data",
|
||||||
|
mcp.Description(dataDesc),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
|
||||||
|
data, ok := args["data"]
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("missing required argument: data"), nil
|
||||||
|
}
|
||||||
|
dataMap, ok := data.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return mcp.NewToolResultError("data must be an object"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := h.executeUpdate(ctx, schema, entity, id, dataMap)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Delete tool
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerDeleteTool(h *Handler, schema, entity string, info modelInfo) {
|
||||||
|
name := toolName("delete", schema, entity)
|
||||||
|
|
||||||
|
descParts := []string{
|
||||||
|
fmt.Sprintf("Delete a record from the '%s' table by its primary key.", info.fullName),
|
||||||
|
}
|
||||||
|
if info.pkName != "" {
|
||||||
|
descParts = append(descParts, fmt.Sprintf("Pass the '%s' value of the record to delete via the 'id' argument.", info.pkName))
|
||||||
|
}
|
||||||
|
descParts = append(descParts, "Returns the deleted record. This operation is irreversible.")
|
||||||
|
|
||||||
|
description := strings.Join(descParts, " ")
|
||||||
|
|
||||||
|
tool := mcp.NewTool(name,
|
||||||
|
mcp.WithDescription(description),
|
||||||
|
mcp.WithString("id",
|
||||||
|
mcp.Description(fmt.Sprintf("Primary key ('%s') of the record to delete.", info.pkName)),
|
||||||
|
mcp.Required(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddTool(tool, func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||||
|
args := req.GetArguments()
|
||||||
|
id, _ := args["id"].(string)
|
||||||
|
|
||||||
|
result, err := h.executeDelete(ctx, schema, entity, id)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(err.Error()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return marshalResult(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Resource registration
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func registerModelResource(h *Handler, schema, entity string, info modelInfo) {
|
||||||
|
resourceURI := info.fullName
|
||||||
|
|
||||||
|
var resourceDesc strings.Builder
|
||||||
|
fmt.Fprintf(&resourceDesc, "Database table: %s", info.fullName)
|
||||||
|
if info.pkName != "" {
|
||||||
|
fmt.Fprintf(&resourceDesc, " (primary key: %s)", info.pkName)
|
||||||
|
}
|
||||||
|
if info.schemaDoc != "" {
|
||||||
|
resourceDesc.WriteString("\n\n")
|
||||||
|
resourceDesc.WriteString(info.schemaDoc)
|
||||||
|
}
|
||||||
|
|
||||||
|
resource := mcp.NewResource(
|
||||||
|
resourceURI,
|
||||||
|
entity,
|
||||||
|
mcp.WithResourceDescription(resourceDesc.String()),
|
||||||
|
mcp.WithMIMEType("application/json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
h.mcpServer.AddResource(resource, func(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||||
|
limit := 100
|
||||||
|
options := common.RequestOptions{Limit: &limit}
|
||||||
|
|
||||||
|
data, metadata, err := h.executeRead(ctx, schema, entity, "", options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"data": data,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error marshaling resource: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []mcp.ResourceContents{
|
||||||
|
mcp.TextResourceContents{
|
||||||
|
URI: req.Params.URI,
|
||||||
|
MIMEType: "application/json",
|
||||||
|
Text: string(jsonBytes),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
// Argument parsing helpers
|
||||||
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// parseRequestOptions converts raw MCP tool arguments into common.RequestOptions.
|
||||||
|
func parseRequestOptions(args map[string]interface{}) common.RequestOptions {
|
||||||
|
options := common.RequestOptions{}
|
||||||
|
|
||||||
|
if v, ok := args["limit"]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
limit := int(n)
|
||||||
|
options.Limit = &limit
|
||||||
|
case int:
|
||||||
|
options.Limit = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := args["offset"]; ok {
|
||||||
|
switch n := v.(type) {
|
||||||
|
case float64:
|
||||||
|
offset := int(n)
|
||||||
|
options.Offset = &offset
|
||||||
|
case int:
|
||||||
|
options.Offset = &n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if v, ok := args["cursor_forward"].(string); ok {
|
||||||
|
options.CursorForward = v
|
||||||
|
}
|
||||||
|
if v, ok := args["cursor_backward"].(string); ok {
|
||||||
|
options.CursorBackward = v
|
||||||
|
}
|
||||||
|
|
||||||
|
options.Columns = parseStringArray(args["columns"])
|
||||||
|
options.OmitColumns = parseStringArray(args["omit_columns"])
|
||||||
|
options.Filters = parseFilters(args["filters"])
|
||||||
|
options.Sort = parseSortOptions(args["sort"])
|
||||||
|
options.Preload = parsePreloadOptions(args["preloads"])
|
||||||
|
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStringArray(raw interface{}) []string {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
if s, ok := item.(string); ok {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFilters(raw interface{}) []common.FilterOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.FilterOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var f common.FilterOption
|
||||||
|
if err := json.Unmarshal(b, &f); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f.Column == "" || f.Operator == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(f.LogicOperator, "or") {
|
||||||
|
f.LogicOperator = "OR"
|
||||||
|
} else {
|
||||||
|
f.LogicOperator = "AND"
|
||||||
|
}
|
||||||
|
result = append(result, f)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSortOptions(raw interface{}) []common.SortOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.SortOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var s common.SortOption
|
||||||
|
if err := json.Unmarshal(b, &s); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.Column == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePreloadOptions(raw interface{}) []common.PreloadOption {
|
||||||
|
if raw == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]common.PreloadOption, 0, len(items))
|
||||||
|
for _, item := range items {
|
||||||
|
b, err := json.Marshal(item)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var p common.PreloadOption
|
||||||
|
if err := json.Unmarshal(b, &p); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if p.Relation == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, p)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// marshalResult marshals a value to JSON and returns it as an MCP text result.
|
||||||
|
func marshalResult(v interface{}) (*mcp.CallToolResult, error) {
|
||||||
|
b, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return mcp.NewToolResultError(fmt.Sprintf("error marshaling result: %v", err)), nil
|
||||||
|
}
|
||||||
|
return mcp.NewToolResultText(string(b)), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,572 @@
|
|||||||
|
# ResolveSpec Query Features Examples
|
||||||
|
|
||||||
|
This document provides examples of using the advanced query features in ResolveSpec, including OR logic filters, Custom Operators, and FetchRowNumber.
|
||||||
|
|
||||||
|
## OR Logic in Filters (SearchOr)
|
||||||
|
|
||||||
|
### Basic OR Filter Example
|
||||||
|
|
||||||
|
Find all users with status "active" OR "pending":
|
||||||
|
|
||||||
|
```json
|
||||||
|
POST /users
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Combined AND/OR Filters
|
||||||
|
|
||||||
|
Find users with (status="active" OR status="pending") AND age >= 18:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "age",
|
||||||
|
"operator": "gte",
|
||||||
|
"value": 18
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||||
|
|
||||||
|
**Important Notes:**
|
||||||
|
- By default, filters use AND logic
|
||||||
|
- Consecutive filters with `"logic_operator": "OR"` are automatically grouped with parentheses
|
||||||
|
- This grouping ensures OR conditions don't interfere with AND conditions
|
||||||
|
- You don't need to specify `"logic_operator": "AND"` as it's the default
|
||||||
|
|
||||||
|
### Multiple OR Groups
|
||||||
|
|
||||||
|
You can have multiple separate OR groups:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "priority",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "high"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "priority",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "urgent",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**SQL Generated:** `WHERE (status = 'active' OR status = 'pending') AND (priority = 'high' OR priority = 'urgent')`
|
||||||
|
|
||||||
|
## Custom Operators
|
||||||
|
|
||||||
|
### Simple Custom SQL Condition
|
||||||
|
|
||||||
|
Filter by email domain using custom SQL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "company_emails",
|
||||||
|
"sql": "email LIKE '%@company.com'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Multiple Custom Operators
|
||||||
|
|
||||||
|
Combine multiple custom SQL conditions:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "recent_active",
|
||||||
|
"sql": "last_login > NOW() - INTERVAL '30 days'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "high_score",
|
||||||
|
"sql": "score > 1000"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Complex Custom Operator
|
||||||
|
|
||||||
|
Use complex SQL expressions:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "priority_users",
|
||||||
|
"sql": "(subscription = 'premium' AND points > 500) OR (subscription = 'enterprise')"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Combining Custom Operators with Regular Filters
|
||||||
|
|
||||||
|
Mix custom operators with standard filters:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "country",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "USA"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "active_last_month",
|
||||||
|
"sql": "last_activity > NOW() - INTERVAL '1 month'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Row Numbers
|
||||||
|
|
||||||
|
### Two Ways to Get Row Numbers
|
||||||
|
|
||||||
|
There are two different features for row numbers:
|
||||||
|
|
||||||
|
1. **`fetch_row_number`** - Get the position of ONE specific record in a sorted/filtered set
|
||||||
|
2. **`RowNumber` field in models** - Automatically number all records in the response
|
||||||
|
|
||||||
|
### 1. FetchRowNumber - Get Position of Specific Record
|
||||||
|
|
||||||
|
Get the rank/position of a specific user in a leaderboard. **Important:** When `fetch_row_number` is specified, the response contains **ONLY that specific record**, not all records.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"direction": "desc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fetch_row_number": "12345"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response - Contains ONLY the specified user:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"id": 12345,
|
||||||
|
"name": "Alice Smith",
|
||||||
|
"score": 9850,
|
||||||
|
"level": 42
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"total": 10000,
|
||||||
|
"count": 1,
|
||||||
|
"filtered": 10000,
|
||||||
|
"row_number": 42
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result:** User "12345" is ranked #42 out of 10,000 users. The response includes only Alice's data, not the other 9,999 users.
|
||||||
|
|
||||||
|
### Row Number with Filters
|
||||||
|
|
||||||
|
Find position within a filtered subset (e.g., "What's my rank in my country?"):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "country",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "USA"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"direction": "desc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fetch_row_number": "12345"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"id": 12345,
|
||||||
|
"name": "Bob Johnson",
|
||||||
|
"country": "USA",
|
||||||
|
"score": 7200,
|
||||||
|
"status": "active"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"total": 2500,
|
||||||
|
"count": 1,
|
||||||
|
"filtered": 2500,
|
||||||
|
"row_number": 156
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Result:** Bob is ranked #156 out of 2,500 active USA users. Only Bob's record is returned.
|
||||||
|
|
||||||
|
### 2. RowNumber Field - Auto-Number All Records
|
||||||
|
|
||||||
|
If your model has a `RowNumber int64` field, restheadspec will automatically populate it for paginated results.
|
||||||
|
|
||||||
|
**Model Definition:**
|
||||||
|
```go
|
||||||
|
type Player struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Score int64 `json:"score"`
|
||||||
|
RowNumber int64 `json:"row_number"` // Will be auto-populated
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Request (with pagination):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"sort": [{"column": "score", "direction": "desc"}],
|
||||||
|
"limit": 10,
|
||||||
|
"offset": 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response - RowNumber automatically set:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"id": 456,
|
||||||
|
"name": "Player21",
|
||||||
|
"score": 8900,
|
||||||
|
"row_number": 21
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 789,
|
||||||
|
"name": "Player22",
|
||||||
|
"score": 8850,
|
||||||
|
"row_number": 22
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 123,
|
||||||
|
"name": "Player23",
|
||||||
|
"score": 8800,
|
||||||
|
"row_number": 23
|
||||||
|
}
|
||||||
|
// ... records 24-30 ...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**How It Works:**
|
||||||
|
- `row_number = offset + index + 1` (1-based)
|
||||||
|
- With offset=20, first record gets row_number=21
|
||||||
|
- With offset=20, second record gets row_number=22
|
||||||
|
- Perfect for displaying "Rank" in paginated tables
|
||||||
|
|
||||||
|
**Use Case:** Displaying leaderboards with rank numbers:
|
||||||
|
```
|
||||||
|
Rank | Player | Score
|
||||||
|
-----|-----------|-------
|
||||||
|
21 | Player21 | 8900
|
||||||
|
22 | Player22 | 8850
|
||||||
|
23 | Player23 | 8800
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note:** This feature is available in all three packages: resolvespec, restheadspec, and websocketspec.
|
||||||
|
|
||||||
|
### When to Use Each Feature
|
||||||
|
|
||||||
|
| Feature | Use Case | Returns | Performance |
|
||||||
|
|---------|----------|---------|-------------|
|
||||||
|
| `fetch_row_number` | "What's my rank?" | 1 record with position | Fast - 1 record |
|
||||||
|
| `RowNumber` field | "Show top 10 with ranks" | Many records numbered | Fast - simple math |
|
||||||
|
|
||||||
|
**Combined Example - Full Leaderboard UI:**
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// Request 1: Get current user's rank
|
||||||
|
const userRank = await api.read({
|
||||||
|
fetch_row_number: currentUserId,
|
||||||
|
sort: [{column: "score", direction: "desc"}]
|
||||||
|
});
|
||||||
|
// Returns: {id: 123, name: "You", score: 7500, row_number: 156}
|
||||||
|
|
||||||
|
// Request 2: Get top 10 with rank numbers
|
||||||
|
const top10 = await api.read({
|
||||||
|
sort: [{column: "score", direction: "desc"}],
|
||||||
|
limit: 10,
|
||||||
|
offset: 0
|
||||||
|
});
|
||||||
|
// Returns: [{row_number: 1, ...}, {row_number: 2, ...}, ...]
|
||||||
|
|
||||||
|
// Display:
|
||||||
|
// "Your Rank: #156"
|
||||||
|
// "Top Players:"
|
||||||
|
// "#1 - Alice - 9999"
|
||||||
|
// "#2 - Bob - 9876"
|
||||||
|
// ...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Complete Example: Advanced Query
|
||||||
|
|
||||||
|
Combine all features for a complex query:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"columns": ["id", "name", "email", "score", "status"],
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "trial",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"operator": "gte",
|
||||||
|
"value": 100
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "recent_activity",
|
||||||
|
"sql": "last_login > NOW() - INTERVAL '7 days'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "verified_email",
|
||||||
|
"sql": "email_verified = true"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"direction": "desc"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "created_at",
|
||||||
|
"direction": "asc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fetch_row_number": "12345",
|
||||||
|
"limit": 50,
|
||||||
|
"offset": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This query:
|
||||||
|
- Selects specific columns
|
||||||
|
- Filters for users with status "active" OR "trial"
|
||||||
|
- AND score >= 100
|
||||||
|
- Applies custom SQL conditions for recent activity and verified emails
|
||||||
|
- Sorts by score (descending) then creation date (ascending)
|
||||||
|
- Returns the row number of user "12345" in this filtered/sorted set
|
||||||
|
- Returns 50 records starting from the first one
|
||||||
|
|
||||||
|
## Use Cases
|
||||||
|
|
||||||
|
### 1. Leaderboards - Get Current User's Rank
|
||||||
|
|
||||||
|
Get the current user's position and data (returns only their record):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "game_id",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "game123"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"direction": "desc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fetch_row_number": "current_user_id"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Tip:** For full leaderboards, make two requests:
|
||||||
|
1. One with `fetch_row_number` to get user's rank
|
||||||
|
2. One with `limit` and `offset` to get top players list
|
||||||
|
|
||||||
|
### 2. Multi-Status Search
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "order_status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "order_status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "processing",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "order_status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "shipped",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Advanced Date Filtering
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "this_month",
|
||||||
|
"sql": "created_at >= DATE_TRUNC('month', CURRENT_DATE)"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "business_hours",
|
||||||
|
"sql": "EXTRACT(HOUR FROM created_at) BETWEEN 9 AND 17"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
**Warning:** Custom operators allow raw SQL, which can be a security risk if not properly handled:
|
||||||
|
|
||||||
|
1. **Never** directly interpolate user input into custom operator SQL
|
||||||
|
2. Always validate and sanitize custom operator SQL on the backend
|
||||||
|
3. Consider using a whitelist of allowed custom operators
|
||||||
|
4. Use prepared statements or parameterized queries when possible
|
||||||
|
5. Implement proper authorization checks before executing queries
|
||||||
|
|
||||||
|
Example of safe custom operator handling in Go:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Whitelist of allowed custom operators
|
||||||
|
allowedOperators := map[string]string{
|
||||||
|
"recent_week": "created_at > NOW() - INTERVAL '7 days'",
|
||||||
|
"active_users": "status = 'active' AND last_login > NOW() - INTERVAL '30 days'",
|
||||||
|
"premium_only": "subscription_level = 'premium'",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate custom operators from request
|
||||||
|
for _, op := range req.Options.CustomOperators {
|
||||||
|
if sql, ok := allowedOperators[op.Name]; ok {
|
||||||
|
op.SQL = sql // Use whitelisted SQL
|
||||||
|
} else {
|
||||||
|
return errors.New("custom operator not allowed: " + op.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
+154
-4
@@ -214,6 +214,146 @@ Content-Type: application/json
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### OR Logic in Filters (SearchOr)
|
||||||
|
|
||||||
|
Use the `logic_operator` field to combine filters with OR logic instead of the default AND:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "priority",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "high",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
This will produce: `WHERE (status = 'active' OR status = 'pending' OR priority = 'high')`
|
||||||
|
|
||||||
|
**Important:** Consecutive OR filters are automatically grouped together with parentheses to ensure proper query logic.
|
||||||
|
|
||||||
|
#### Mixing AND and OR
|
||||||
|
|
||||||
|
Consecutive OR filters are grouped, then combined with AND filters:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "pending",
|
||||||
|
"logic_operator": "OR"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"column": "age",
|
||||||
|
"operator": "gte",
|
||||||
|
"value": 18
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Produces: `WHERE (status = 'active' OR status = 'pending') AND age >= 18`
|
||||||
|
|
||||||
|
This grouping ensures OR conditions don't interfere with other AND conditions in the query.
|
||||||
|
|
||||||
|
### Custom Operators
|
||||||
|
|
||||||
|
Add custom SQL conditions when needed:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"customOperators": [
|
||||||
|
{
|
||||||
|
"name": "email_domain_filter",
|
||||||
|
"sql": "LOWER(email) LIKE '%@example.com'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "recent_records",
|
||||||
|
"sql": "created_at > NOW() - INTERVAL '7 days'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Custom operators are applied as additional WHERE conditions to your query.
|
||||||
|
|
||||||
|
### Fetch Row Number
|
||||||
|
|
||||||
|
Get the row number (position) of a specific record in the filtered and sorted result set. **When `fetch_row_number` is specified, only that specific record is returned** (not all records).
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"operation": "read",
|
||||||
|
"options": {
|
||||||
|
"filters": [
|
||||||
|
{
|
||||||
|
"column": "status",
|
||||||
|
"operator": "eq",
|
||||||
|
"value": "active"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"sort": [
|
||||||
|
{
|
||||||
|
"column": "score",
|
||||||
|
"direction": "desc"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"fetch_row_number": "12345"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response - Returns ONLY the specified record with its position:**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"id": 12345,
|
||||||
|
"name": "John Doe",
|
||||||
|
"score": 850,
|
||||||
|
"status": "active"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"total": 1000,
|
||||||
|
"count": 1,
|
||||||
|
"filtered": 1000,
|
||||||
|
"row_number": 42
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Use Case:** Perfect for "Show me this user and their ranking" - you get just that one user with their position in the leaderboard.
|
||||||
|
|
||||||
|
**Note:** This is different from the `RowNumber` field feature, which automatically numbers all records in a paginated response based on offset. That feature uses simple math (`offset + index + 1`), while `fetch_row_number` uses SQL window functions to calculate the actual position in a sorted/filtered set. To use the `RowNumber` field feature, simply add a `RowNumber int64` field to your model - it will be automatically populated with the row position based on pagination.
|
||||||
|
|
||||||
## Preloading
|
## Preloading
|
||||||
|
|
||||||
Load related entities with custom configuration:
|
Load related entities with custom configuration:
|
||||||
@@ -427,7 +567,7 @@ Define virtual columns using SQL expressions:
|
|||||||
|
|
||||||
## Custom Operators
|
## Custom Operators
|
||||||
|
|
||||||
Add custom SQL conditions when needed:
|
Add custom SQL conditions when standard filters aren't sufficient:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@@ -435,17 +575,24 @@ Add custom SQL conditions when needed:
|
|||||||
"options": {
|
"options": {
|
||||||
"customOperators": [
|
"customOperators": [
|
||||||
{
|
{
|
||||||
"condition": "LOWER(email) LIKE ?",
|
"name": "email_domain_filter",
|
||||||
"values": ["%@example.com"]
|
"sql": "LOWER(email) LIKE '%@example.com'"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"condition": "created_at > NOW() - INTERVAL '7 days'"
|
"name": "recent_records",
|
||||||
|
"sql": "created_at > NOW() - INTERVAL '7 days'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "complex_condition",
|
||||||
|
"sql": "(status = 'active' AND score > 100) OR (status = 'pending' AND priority = 'high')"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Note:** Custom operators are applied as WHERE conditions. Make sure to properly escape and sanitize any user input to prevent SQL injection.
|
||||||
|
|
||||||
## Lifecycle Hooks
|
## Lifecycle Hooks
|
||||||
|
|
||||||
Register hooks for all CRUD operations:
|
Register hooks for all CRUD operations:
|
||||||
@@ -497,6 +644,7 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Available Hook Types**:
|
**Available Hook Types**:
|
||||||
|
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||||
* `BeforeRead`, `AfterRead`
|
* `BeforeRead`, `AfterRead`
|
||||||
* `BeforeCreate`, `AfterCreate`
|
* `BeforeCreate`, `AfterCreate`
|
||||||
* `BeforeUpdate`, `AfterUpdate`
|
* `BeforeUpdate`, `AfterUpdate`
|
||||||
@@ -507,11 +655,13 @@ handler.Hooks().Register(resolvespec.BeforeCreate, func(ctx *resolvespec.HookCon
|
|||||||
* `Handler`: Access to handler, database, and registry
|
* `Handler`: Access to handler, database, and registry
|
||||||
* `Schema`, `Entity`, `TableName`: Request info
|
* `Schema`, `Entity`, `TableName`: Request info
|
||||||
* `Model`: The registered model type
|
* `Model`: The registered model type
|
||||||
|
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||||
* `ID`: Record ID (for single-record operations)
|
* `ID`: Record ID (for single-record operations)
|
||||||
* `Data`: Request data (for create/update)
|
* `Data`: Request data (for create/update)
|
||||||
* `Result`: Operation result (for after hooks)
|
* `Result`: Operation result (for after hooks)
|
||||||
* `Writer`: Response writer (allows hooks to modify response)
|
* `Writer`: Response writer (allows hooks to modify response)
|
||||||
|
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||||
|
|
||||||
## Model Registration
|
## Model Registration
|
||||||
|
|
||||||
|
|||||||
+44
-13
@@ -24,6 +24,7 @@ const (
|
|||||||
// - pkName: primary key column (e.g. "id")
|
// - pkName: primary key column (e.g. "id")
|
||||||
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
// - modelColumns: optional list of valid main-table columns (for validation). Pass nil to skip.
|
||||||
// - options: the request options containing sort and cursor information
|
// - options: the request options containing sort and cursor information
|
||||||
|
// - expandJoins: optional map[alias]string of JOIN clauses for join-column sort support
|
||||||
//
|
//
|
||||||
// Returns SQL snippet to embed in WHERE clause.
|
// Returns SQL snippet to embed in WHERE clause.
|
||||||
func GetCursorFilter(
|
func GetCursorFilter(
|
||||||
@@ -31,8 +32,10 @@ func GetCursorFilter(
|
|||||||
pkName string,
|
pkName string,
|
||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
options common.RequestOptions,
|
options common.RequestOptions,
|
||||||
|
expandJoins map[string]string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
// Remove schema prefix if present
|
// Separate schema prefix from bare table name
|
||||||
|
fullTableName := tableName
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
}
|
}
|
||||||
@@ -57,18 +60,19 @@ func GetCursorFilter(
|
|||||||
// 3. Prepare
|
// 3. Prepare
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
var whereClauses []string
|
var whereClauses []string
|
||||||
|
joinSQL := ""
|
||||||
reverse := direction < 0
|
reverse := direction < 0
|
||||||
|
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
// 4. Process each sort column
|
// 4. Process each sort column
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
for _, s := range sortItems {
|
for _, s := range sortItems {
|
||||||
col := strings.TrimSpace(s.Column)
|
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||||
if col == "" {
|
if col == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse: "created_at", "user.name", etc.
|
// Parse: "created_at", "user.name", "fn.sortorder", etc.
|
||||||
parts := strings.Split(col, ".")
|
parts := strings.Split(col, ".")
|
||||||
field := strings.TrimSpace(parts[len(parts)-1])
|
field := strings.TrimSpace(parts[len(parts)-1])
|
||||||
prefix := strings.Join(parts[:len(parts)-1], ".")
|
prefix := strings.Join(parts[:len(parts)-1], ".")
|
||||||
@@ -81,7 +85,7 @@ func GetCursorFilter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Resolve column
|
// Resolve column
|
||||||
cursorCol, targetCol, err := resolveColumn(
|
cursorCol, targetCol, isJoin, err := resolveColumn(
|
||||||
field, prefix, tableName, modelColumns,
|
field, prefix, tableName, modelColumns,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -89,6 +93,22 @@ func GetCursorFilter(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle joins
|
||||||
|
if isJoin {
|
||||||
|
if expandJoins != nil {
|
||||||
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
|
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||||
|
joinSQL = jSQL
|
||||||
|
cursorCol = cRef + "." + field
|
||||||
|
targetCol = prefix + "." + field
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build inequality
|
// Build inequality
|
||||||
op := "<"
|
op := "<"
|
||||||
if desc {
|
if desc {
|
||||||
@@ -112,10 +132,12 @@ func GetCursorFilter(
|
|||||||
query := fmt.Sprintf(`EXISTS (
|
query := fmt.Sprintf(`EXISTS (
|
||||||
SELECT 1
|
SELECT 1
|
||||||
FROM %s cursor_select
|
FROM %s cursor_select
|
||||||
|
%s
|
||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
tableName,
|
fullTableName,
|
||||||
|
joinSQL,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
orSQL,
|
orSQL,
|
||||||
@@ -136,35 +158,44 @@ func getActiveCursor(options common.RequestOptions) (id string, direction Cursor
|
|||||||
return "", 0
|
return "", 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper: resolve column (main table only for now)
|
// Helper: resolve column (main table or join)
|
||||||
func resolveColumn(
|
func resolveColumn(
|
||||||
field, prefix, tableName string,
|
field, prefix, tableName string,
|
||||||
modelColumns []string,
|
modelColumns []string,
|
||||||
) (cursorCol, targetCol string, err error) {
|
) (cursorCol, targetCol string, isJoin bool, err error) {
|
||||||
|
|
||||||
// JSON field
|
// JSON field
|
||||||
if strings.Contains(field, "->") {
|
if strings.Contains(field, "->") {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Main table column
|
// Main table column
|
||||||
if modelColumns != nil {
|
if modelColumns != nil {
|
||||||
for _, col := range modelColumns {
|
for _, col := range modelColumns {
|
||||||
if strings.EqualFold(col, field) {
|
if strings.EqualFold(col, field) {
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No validation → allow all main-table fields
|
// No validation → allow all main-table fields
|
||||||
return "cursor_select." + field, tableName + "." + field, nil
|
return "cursor_select." + field, tableName + "." + field, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Joined column (not supported in resolvespec yet)
|
// Joined column
|
||||||
if prefix != "" && prefix != tableName {
|
if prefix != "" && prefix != tableName {
|
||||||
return "", "", fmt.Errorf("joined columns not supported in cursor pagination: %s", field)
|
return "", "", true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return "", "", fmt.Errorf("invalid column: %s", field)
|
return "", "", false, fmt.Errorf("invalid column: %s", field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper: rewrite JOIN clause for cursor subquery
|
||||||
|
func rewriteJoin(joinClause, mainTable, alias string) (joinSQL, cursorAlias string) {
|
||||||
|
joinSQL = strings.ReplaceAll(joinClause, mainTable+".", "cursor_select.")
|
||||||
|
cursorAlias = "cursor_select_" + alias
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+" ", " "+cursorAlias+" ")
|
||||||
|
joinSQL = strings.ReplaceAll(joinSQL, " "+alias+".", " "+cursorAlias+".")
|
||||||
|
return joinSQL, cursorAlias
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------------- //
|
// ------------------------------------------------------------------------- //
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ func TestGetCursorFilter_Forward(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -65,7 +65,7 @@ func TestGetCursorFilter_Backward(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
modelColumns := []string{"id", "title", "created_at", "user_id"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -96,7 +96,7 @@ func TestGetCursorFilter_NoCursor(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "created_at"}
|
modelColumns := []string{"id", "title", "created_at"}
|
||||||
|
|
||||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when no cursor is provided")
|
t.Error("Expected error when no cursor is provided")
|
||||||
}
|
}
|
||||||
@@ -116,7 +116,7 @@ func TestGetCursorFilter_NoSort(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title"}
|
modelColumns := []string{"id", "title"}
|
||||||
|
|
||||||
_, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
_, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error when no sort columns are defined")
|
t.Error("Expected error when no sort columns are defined")
|
||||||
}
|
}
|
||||||
@@ -140,7 +140,7 @@ func TestGetCursorFilter_MultiColumnSort(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "title", "priority", "created_at"}
|
modelColumns := []string{"id", "title", "priority", "created_at"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -170,19 +170,50 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "name", "email"}
|
modelColumns := []string{"id", "name", "email"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should handle schema prefix properly
|
// Should include full schema-qualified name in FROM clause
|
||||||
if !strings.Contains(filter, "users") {
|
if !strings.Contains(filter, "public.users") {
|
||||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||||
|
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||||
|
|
||||||
|
options := common.RequestOptions{
|
||||||
|
Sort: []common.SortOption{{Column: "fn.sortorder", Direction: "ASC"}},
|
||||||
|
CursorForward: "8975",
|
||||||
|
}
|
||||||
|
|
||||||
|
tableName := "core.account"
|
||||||
|
pkName := "rid_account"
|
||||||
|
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||||
|
expandJoins := map[string]string{"fn": lateralJoin}
|
||||||
|
|
||||||
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, expandJoins)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||||
|
|
||||||
|
if !strings.Contains(filter, "cursor_select_fn") {
|
||||||
|
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||||
|
}
|
||||||
|
if !strings.Contains(filter, "sortorder") {
|
||||||
|
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||||
|
}
|
||||||
|
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||||
|
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetActiveCursor(t *testing.T) {
|
func TestGetActiveCursor(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@@ -288,18 +319,19 @@ func TestResolveColumn(t *testing.T) {
|
|||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Joined column (not supported)",
|
name: "Joined column (isJoin=true, no error)",
|
||||||
field: "name",
|
field: "name",
|
||||||
prefix: "user",
|
prefix: "user",
|
||||||
tableName: "posts",
|
tableName: "posts",
|
||||||
modelColumns: []string{"id", "title"},
|
modelColumns: []string{"id", "title"},
|
||||||
wantErr: true,
|
wantErr: false,
|
||||||
|
// cursorCol and targetCol are empty when isJoin=true; handled by caller
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
cursor, target, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
cursor, target, isJoin, err := resolveColumn(tt.field, tt.prefix, tt.tableName, tt.modelColumns)
|
||||||
|
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -312,6 +344,14 @@ func TestResolveColumn(t *testing.T) {
|
|||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For join columns, cursor/target are empty and isJoin=true
|
||||||
|
if isJoin {
|
||||||
|
if cursor != "" || target != "" {
|
||||||
|
t.Errorf("Expected empty cursor/target for join column, got %q / %q", cursor, target)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if cursor != tt.wantCursor {
|
if cursor != tt.wantCursor {
|
||||||
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
t.Errorf("Expected cursor %q, got %q", tt.wantCursor, cursor)
|
||||||
}
|
}
|
||||||
@@ -362,7 +402,7 @@ func TestCursorFilter_SQL_Safety(t *testing.T) {
|
|||||||
pkName := "id"
|
pkName := "id"
|
||||||
modelColumns := []string{"id", "created_at"}
|
modelColumns := []string{"id", "created_at"}
|
||||||
|
|
||||||
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
filter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,143 @@
|
|||||||
|
package resolvespec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestBuildFilterCondition tests the filter condition builder
|
||||||
|
func TestBuildFilterCondition(t *testing.T) {
|
||||||
|
h := &Handler{}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
filter common.FilterOption
|
||||||
|
expectedCondition string
|
||||||
|
expectedArgsCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Equal operator",
|
||||||
|
filter: common.FilterOption{
|
||||||
|
Column: "status",
|
||||||
|
Operator: "eq",
|
||||||
|
Value: "active",
|
||||||
|
},
|
||||||
|
expectedCondition: "status = ?",
|
||||||
|
expectedArgsCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Greater than operator",
|
||||||
|
filter: common.FilterOption{
|
||||||
|
Column: "age",
|
||||||
|
Operator: "gt",
|
||||||
|
Value: 18,
|
||||||
|
},
|
||||||
|
expectedCondition: "age > ?",
|
||||||
|
expectedArgsCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IN operator",
|
||||||
|
filter: common.FilterOption{
|
||||||
|
Column: "status",
|
||||||
|
Operator: "in",
|
||||||
|
Value: []string{"active", "pending"},
|
||||||
|
},
|
||||||
|
expectedCondition: "status IN (?,?)",
|
||||||
|
expectedArgsCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LIKE operator",
|
||||||
|
filter: common.FilterOption{
|
||||||
|
Column: "email",
|
||||||
|
Operator: "like",
|
||||||
|
Value: "%@example.com",
|
||||||
|
},
|
||||||
|
expectedCondition: "CAST(email AS TEXT) LIKE ?",
|
||||||
|
expectedArgsCount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
condition, args := h.buildFilterCondition(tt.filter)
|
||||||
|
|
||||||
|
if condition != tt.expectedCondition {
|
||||||
|
t.Errorf("Expected condition '%s', got '%s'", tt.expectedCondition, condition)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) != tt.expectedArgsCount {
|
||||||
|
t.Errorf("Expected %d args, got %d", tt.expectedArgsCount, len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Skip value comparison for slices as they can't be compared with ==
|
||||||
|
// The important part is that args are populated correctly
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestORGrouping tests that consecutive OR filters are properly grouped
|
||||||
|
func TestORGrouping(t *testing.T) {
|
||||||
|
// This is a conceptual test - in practice we'd need a mock SelectQuery
|
||||||
|
// to verify the actual SQL grouping behavior
|
||||||
|
t.Run("Consecutive OR filters should be grouped", func(t *testing.T) {
|
||||||
|
filters := []common.FilterOption{
|
||||||
|
{Column: "status", Operator: "eq", Value: "active"},
|
||||||
|
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||||
|
{Column: "status", Operator: "eq", Value: "trial", LogicOperator: "OR"},
|
||||||
|
{Column: "age", Operator: "gte", Value: 18},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected behavior: (status='active' OR status='pending' OR status='trial') AND age>=18
|
||||||
|
// The first three filters should be grouped together
|
||||||
|
// The fourth filter should be separate with AND
|
||||||
|
|
||||||
|
// Count OR groups
|
||||||
|
orGroupCount := 0
|
||||||
|
inORGroup := false
|
||||||
|
|
||||||
|
for i := 1; i < len(filters); i++ {
|
||||||
|
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||||
|
orGroupCount++
|
||||||
|
inORGroup = true
|
||||||
|
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||||
|
inORGroup = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should have detected one OR group
|
||||||
|
if orGroupCount != 1 {
|
||||||
|
t.Errorf("Expected 1 OR group, detected %d", orGroupCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Multiple OR groups should be handled correctly", func(t *testing.T) {
|
||||||
|
filters := []common.FilterOption{
|
||||||
|
{Column: "status", Operator: "eq", Value: "active"},
|
||||||
|
{Column: "status", Operator: "eq", Value: "pending", LogicOperator: "OR"},
|
||||||
|
{Column: "priority", Operator: "eq", Value: "high"},
|
||||||
|
{Column: "priority", Operator: "eq", Value: "urgent", LogicOperator: "OR"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expected: (status='active' OR status='pending') AND (priority='high' OR priority='urgent')
|
||||||
|
// Should have two OR groups
|
||||||
|
|
||||||
|
orGroupCount := 0
|
||||||
|
inORGroup := false
|
||||||
|
|
||||||
|
for i := 1; i < len(filters); i++ {
|
||||||
|
if strings.EqualFold(filters[i].LogicOperator, "OR") && !inORGroup {
|
||||||
|
orGroupCount++
|
||||||
|
inORGroup = true
|
||||||
|
} else if !strings.EqualFold(filters[i].LogicOperator, "OR") {
|
||||||
|
inORGroup = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We should have detected two OR groups
|
||||||
|
if orGroupCount != 2 {
|
||||||
|
t.Errorf("Expected 2 OR groups, detected %d", orGroupCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
+368
-37
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -138,6 +139,26 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
validator := common.NewColumnValidator(model)
|
validator := common.NewColumnValidator(model)
|
||||||
req.Options = validator.FilterRequestOptions(req.Options)
|
req.Options = validator.FilterRequestOptions(req.Options)
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
beforeCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Writer: w,
|
||||||
|
Request: r,
|
||||||
|
Operation: req.Operation,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||||
|
code := http.StatusUnauthorized
|
||||||
|
if beforeCtx.AbortCode != 0 {
|
||||||
|
code = beforeCtx.AbortCode
|
||||||
|
}
|
||||||
|
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch req.Operation {
|
switch req.Operation {
|
||||||
case "read":
|
case "read":
|
||||||
h.handleRead(ctx, w, id, req.Options)
|
h.handleRead(ctx, w, id, req.Options)
|
||||||
@@ -280,10 +301,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply filters
|
// Apply filters with proper grouping for OR logic
|
||||||
for _, filter := range options.Filters {
|
query = h.applyFilters(query, options.Filters)
|
||||||
logger.Debug("Applying filter: %s %s %v", filter.Column, filter.Operator, filter.Value)
|
|
||||||
query = h.applyFilter(query, filter)
|
// Apply custom operators
|
||||||
|
for _, customOp := range options.CustomOperators {
|
||||||
|
logger.Debug("Applying custom operator: %s - %s", customOp.Name, customOp.SQL)
|
||||||
|
query = query.Where(customOp.SQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply sorting
|
// Apply sorting
|
||||||
@@ -306,8 +330,13 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Extract model columns for validation
|
// Extract model columns for validation
|
||||||
modelColumns := reflection.GetModelColumns(model)
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
// Get cursor filter SQL
|
// Default sort to primary key when none provided
|
||||||
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options)
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get cursor filter SQL (expandJoins is empty for resolvespec — no custom SQL join support yet)
|
||||||
|
cursorFilter, err := GetCursorFilter(tableName, pkName, modelColumns, options, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error building cursor filter: %v", err)
|
logger.Error("Error building cursor filter: %v", err)
|
||||||
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
h.sendError(w, http.StatusBadRequest, "cursor_error", "Invalid cursor pagination", err)
|
||||||
@@ -381,7 +410,77 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply pagination
|
// Handle FetchRowNumber if requested
|
||||||
|
var rowNumber *int64
|
||||||
|
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||||
|
logger.Debug("Fetching row number for ID: %s", *options.FetchRowNumber)
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
|
||||||
|
// Build ROW_NUMBER window function SQL
|
||||||
|
rowNumberSQL := "ROW_NUMBER() OVER ("
|
||||||
|
if len(options.Sort) > 0 {
|
||||||
|
rowNumberSQL += "ORDER BY "
|
||||||
|
for i, sort := range options.Sort {
|
||||||
|
if i > 0 {
|
||||||
|
rowNumberSQL += ", "
|
||||||
|
}
|
||||||
|
direction := "ASC"
|
||||||
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
|
direction = "DESC"
|
||||||
|
}
|
||||||
|
rowNumberSQL += fmt.Sprintf("%s %s", sort.Column, direction)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rowNumberSQL += ")"
|
||||||
|
|
||||||
|
// Create a query to fetch the row number using a subquery approach
|
||||||
|
// We'll select the PK and row_number, then filter by the target ID
|
||||||
|
type RowNumResult struct {
|
||||||
|
RowNum int64 `bun:"row_num"`
|
||||||
|
}
|
||||||
|
|
||||||
|
rowNumQuery := h.db.NewSelect().Table(tableName).
|
||||||
|
ColumnExpr(fmt.Sprintf("%s AS row_num", rowNumberSQL)).
|
||||||
|
Column(pkName)
|
||||||
|
|
||||||
|
// Apply the same filters as the main query
|
||||||
|
for _, filter := range options.Filters {
|
||||||
|
rowNumQuery = h.applyFilter(rowNumQuery, filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply custom operators
|
||||||
|
for _, customOp := range options.CustomOperators {
|
||||||
|
rowNumQuery = rowNumQuery.Where(customOp.SQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter for the specific ID we want the row number for
|
||||||
|
rowNumQuery = rowNumQuery.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), *options.FetchRowNumber)
|
||||||
|
|
||||||
|
// Execute query to get row number
|
||||||
|
var result RowNumResult
|
||||||
|
if err := rowNumQuery.Scan(ctx, &result); err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// Build filter description for error message
|
||||||
|
filterInfo := fmt.Sprintf("filters: %d", len(options.Filters))
|
||||||
|
if len(options.CustomOperators) > 0 {
|
||||||
|
customOps := make([]string, 0, len(options.CustomOperators))
|
||||||
|
for _, op := range options.CustomOperators {
|
||||||
|
customOps = append(customOps, op.SQL)
|
||||||
|
}
|
||||||
|
filterInfo += fmt.Sprintf(", custom operators: [%s]", strings.Join(customOps, "; "))
|
||||||
|
}
|
||||||
|
logger.Warn("No row found for primary key %s=%s with %s", pkName, *options.FetchRowNumber, filterInfo)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Error fetching row number: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
rowNumber = &result.RowNum
|
||||||
|
logger.Debug("Found row number: %d", *rowNumber)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply pagination (skip if FetchRowNumber is set - we want only that record)
|
||||||
|
if options.FetchRowNumber == nil || *options.FetchRowNumber == "" {
|
||||||
if options.Limit != nil && *options.Limit > 0 {
|
if options.Limit != nil && *options.Limit > 0 {
|
||||||
logger.Debug("Applying limit: %d", *options.Limit)
|
logger.Debug("Applying limit: %d", *options.Limit)
|
||||||
query = query.Limit(*options.Limit)
|
query = query.Limit(*options.Limit)
|
||||||
@@ -390,15 +489,26 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
logger.Debug("Applying offset: %d", *options.Offset)
|
logger.Debug("Applying offset: %d", *options.Offset)
|
||||||
query = query.Offset(*options.Offset)
|
query = query.Offset(*options.Offset)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Execute query
|
// Execute query
|
||||||
var result interface{}
|
var result interface{}
|
||||||
|
if id != "" || (options.FetchRowNumber != nil && *options.FetchRowNumber != "") {
|
||||||
|
// Single record query - either by URL ID or FetchRowNumber
|
||||||
|
var targetID string
|
||||||
if id != "" {
|
if id != "" {
|
||||||
logger.Debug("Querying single record with ID: %s", id)
|
targetID = id
|
||||||
|
logger.Debug("Querying single record with URL ID: %s", id)
|
||||||
|
} else {
|
||||||
|
targetID = *options.FetchRowNumber
|
||||||
|
logger.Debug("Querying single record with FetchRowNumber ID: %s", targetID)
|
||||||
|
}
|
||||||
|
|
||||||
// For single record, create a new pointer to the struct type
|
// For single record, create a new pointer to the struct type
|
||||||
singleResult := reflect.New(modelType).Interface()
|
singleResult := reflect.New(modelType).Interface()
|
||||||
|
pkName := reflection.GetPrimaryKeyName(singleResult)
|
||||||
|
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(singleResult))), id)
|
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
if err := query.Scan(ctx, singleResult); err != nil {
|
if err := query.Scan(ctx, singleResult); err != nil {
|
||||||
logger.Error("Error querying record: %v", err)
|
logger.Error("Error querying record: %v", err)
|
||||||
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
h.sendError(w, http.StatusInternalServerError, "query_error", "Error executing query", err)
|
||||||
@@ -418,20 +528,39 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
|
|
||||||
logger.Info("Successfully retrieved records")
|
logger.Info("Successfully retrieved records")
|
||||||
|
|
||||||
|
// Build metadata
|
||||||
limit := 0
|
limit := 0
|
||||||
|
offset := 0
|
||||||
|
count := int64(total)
|
||||||
|
|
||||||
|
// When FetchRowNumber is used, we only return 1 record
|
||||||
|
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||||
|
count = 1
|
||||||
|
// Set the fetched row number on the record
|
||||||
|
if rowNumber != nil {
|
||||||
|
logger.Debug("FetchRowNumber: Setting row number %d on record", *rowNumber)
|
||||||
|
h.setRowNumbersOnRecords(result, int(*rowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if options.Limit != nil {
|
if options.Limit != nil {
|
||||||
limit = *options.Limit
|
limit = *options.Limit
|
||||||
}
|
}
|
||||||
offset := 0
|
|
||||||
if options.Offset != nil {
|
if options.Offset != nil {
|
||||||
offset = *options.Offset
|
offset = *options.Offset
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set row numbers on records if RowNumber field exists
|
||||||
|
// Only for multiple records (not when fetching single record)
|
||||||
|
h.setRowNumbersOnRecords(result, offset)
|
||||||
|
}
|
||||||
|
|
||||||
h.sendResponse(w, result, &common.Metadata{
|
h.sendResponse(w, result, &common.Metadata{
|
||||||
Total: int64(total),
|
Total: int64(total),
|
||||||
Filtered: int64(total),
|
Filtered: int64(total),
|
||||||
|
Count: count,
|
||||||
Limit: limit,
|
Limit: limit,
|
||||||
Offset: offset,
|
Offset: offset,
|
||||||
|
RowNumber: rowNumber,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,7 +604,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
// Standard processing without nested relations
|
// Standard processing without nested relations
|
||||||
query := h.db.NewInsert().Table(tableName)
|
query := h.db.NewInsert().Table(tableName)
|
||||||
for key, value := range v {
|
for key, value := range v {
|
||||||
query = query.Value(key, value)
|
query = query.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
result, err := query.Exec(ctx)
|
result, err := query.Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -541,7 +670,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
for _, item := range v {
|
for _, item := range v {
|
||||||
txQuery := tx.NewInsert().Table(tableName)
|
txQuery := tx.NewInsert().Table(tableName)
|
||||||
for key, value := range item {
|
for key, value := range item {
|
||||||
txQuery = txQuery.Value(key, value)
|
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -619,7 +748,7 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
if itemMap, ok := item.(map[string]interface{}); ok {
|
if itemMap, ok := item.(map[string]interface{}); ok {
|
||||||
txQuery := tx.NewInsert().Table(tableName)
|
txQuery := tx.NewInsert().Table(tableName)
|
||||||
for key, value := range itemMap {
|
for key, value := range itemMap {
|
||||||
txQuery = txQuery.Value(key, value)
|
txQuery = txQuery.Value(key, common.ConvertSliceForBun(value))
|
||||||
}
|
}
|
||||||
if _, err := txQuery.Exec(ctx); err != nil {
|
if _, err := txQuery.Exec(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -1133,6 +1262,24 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
logger.Info("Deleting records from %s.%s", schema, entity)
|
logger.Info("Deleting records from %s.%s", schema, entity)
|
||||||
|
|
||||||
|
// Execute BeforeDelete hooks (covers model-rule checks before any deletion)
|
||||||
|
hookCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
ID: id,
|
||||||
|
Data: data,
|
||||||
|
Writer: w,
|
||||||
|
Tx: h.db,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
|
logger.Error("BeforeDelete hook failed: %v", err)
|
||||||
|
h.sendError(w, http.StatusForbidden, "delete_forbidden", "Delete operation not allowed", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Handle batch delete from request data
|
// Handle batch delete from request data
|
||||||
if data != nil {
|
if data != nil {
|
||||||
switch v := data.(type) {
|
switch v := data.(type) {
|
||||||
@@ -1303,29 +1450,165 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
h.sendResponse(w, recordToDelete, nil)
|
h.sendResponse(w, recordToDelete, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
// applyFilters applies all filters with proper grouping for OR logic
|
||||||
|
// Groups consecutive OR filters together to ensure proper query precedence
|
||||||
|
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||||
|
func (h *Handler) applyFilters(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for i < len(filters) {
|
||||||
|
// Check if this starts an OR group (current or next filter has OR logic)
|
||||||
|
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||||
|
|
||||||
|
if startORGroup {
|
||||||
|
// Collect all consecutive filters that are OR'd together
|
||||||
|
orGroup := []common.FilterOption{filters[i]}
|
||||||
|
j := i + 1
|
||||||
|
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||||
|
orGroup = append(orGroup, filters[j])
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the OR group as a single grouped WHERE clause
|
||||||
|
query = h.applyFilterGroup(query, orGroup)
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
// Single filter with AND logic (or first filter)
|
||||||
|
condition, args := h.buildFilterCondition(filters[i])
|
||||||
|
if condition != "" {
|
||||||
|
query = query.Where(condition, args...)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyFilterGroup applies a group of filters that should be OR'd together
|
||||||
|
// Always wraps them in parentheses and applies as a single WHERE clause
|
||||||
|
func (h *Handler) applyFilterGroup(query common.SelectQuery, filters []common.FilterOption) common.SelectQuery {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build all conditions and collect args
|
||||||
|
var conditions []string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
for _, filter := range filters {
|
||||||
|
condition, filterArgs := h.buildFilterCondition(filter)
|
||||||
|
if condition != "" {
|
||||||
|
conditions = append(conditions, condition)
|
||||||
|
args = append(args, filterArgs...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(conditions) == 0 {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single filter - no need for grouping
|
||||||
|
if len(conditions) == 1 {
|
||||||
|
return query.Where(conditions[0], args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple conditions - group with parentheses and OR
|
||||||
|
groupedCondition := "(" + strings.Join(conditions, " OR ") + ")"
|
||||||
|
return query.Where(groupedCondition, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildFilterCondition builds a filter condition and returns it with args
|
||||||
|
func (h *Handler) buildFilterCondition(filter common.FilterOption) (conditionString string, conditionArgs []interface{}) {
|
||||||
|
var condition string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
switch filter.Operator {
|
switch filter.Operator {
|
||||||
case "eq":
|
case "eq", "=":
|
||||||
return query.Where(fmt.Sprintf("%s = ?", filter.Column), filter.Value)
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||||
case "neq":
|
args = []interface{}{filter.Value}
|
||||||
return query.Where(fmt.Sprintf("%s != ?", filter.Column), filter.Value)
|
case "neq", "!=", "<>":
|
||||||
case "gt":
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||||
return query.Where(fmt.Sprintf("%s > ?", filter.Column), filter.Value)
|
args = []interface{}{filter.Value}
|
||||||
case "gte":
|
case "gt", ">":
|
||||||
return query.Where(fmt.Sprintf("%s >= ?", filter.Column), filter.Value)
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||||
case "lt":
|
args = []interface{}{filter.Value}
|
||||||
return query.Where(fmt.Sprintf("%s < ?", filter.Column), filter.Value)
|
case "gte", ">=":
|
||||||
case "lte":
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||||
return query.Where(fmt.Sprintf("%s <= ?", filter.Column), filter.Value)
|
args = []interface{}{filter.Value}
|
||||||
|
case "lt", "<":
|
||||||
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "lte", "<=":
|
||||||
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
case "like":
|
case "like":
|
||||||
return query.Where(fmt.Sprintf("%s LIKE ?", filter.Column), filter.Value)
|
condition = fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
case "ilike":
|
case "ilike":
|
||||||
return query.Where(fmt.Sprintf("%s ILIKE ?", filter.Column), filter.Value)
|
condition = fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
case "in":
|
case "in":
|
||||||
return query.Where(fmt.Sprintf("%s IN (?)", filter.Column), filter.Value)
|
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||||
|
if condition == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return condition, args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption) common.SelectQuery {
|
||||||
|
// Determine which method to use based on LogicOperator
|
||||||
|
useOrLogic := strings.EqualFold(filter.LogicOperator, "OR")
|
||||||
|
|
||||||
|
var condition string
|
||||||
|
var args []interface{}
|
||||||
|
|
||||||
|
switch filter.Operator {
|
||||||
|
case "eq", "=":
|
||||||
|
condition = fmt.Sprintf("%s = ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "neq", "!=", "<>":
|
||||||
|
condition = fmt.Sprintf("%s != ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "gt", ">":
|
||||||
|
condition = fmt.Sprintf("%s > ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "gte", ">=":
|
||||||
|
condition = fmt.Sprintf("%s >= ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "lt", "<":
|
||||||
|
condition = fmt.Sprintf("%s < ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "lte", "<=":
|
||||||
|
condition = fmt.Sprintf("%s <= ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "like":
|
||||||
|
condition = fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "ilike":
|
||||||
|
condition = fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", filter.Column)
|
||||||
|
args = []interface{}{filter.Value}
|
||||||
|
case "in":
|
||||||
|
condition, args = common.BuildInCondition(filter.Column, filter.Value)
|
||||||
|
if condition == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return query
|
return query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Apply filter with appropriate logic operator
|
||||||
|
if useOrLogic {
|
||||||
|
return query.WhereOr(condition, args...)
|
||||||
|
}
|
||||||
|
return query.Where(condition, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseTableName splits a table name that may contain schema into separate schema and table
|
// parseTableName splits a table name that may contain schema into separate schema and table
|
||||||
@@ -1475,18 +1758,21 @@ func (h *Handler) sendResponse(w common.ResponseWriter, data interface{}, metada
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) sendError(w common.ResponseWriter, status int, code, message string, details interface{}) {
|
func (h *Handler) sendError(w common.ResponseWriter, status int, code, message string, details interface{}) {
|
||||||
w.SetHeader("Content-Type", "application/json")
|
apiErr := &common.APIError{
|
||||||
w.WriteHeader(status)
|
|
||||||
err := w.WriteJSON(common.Response{
|
|
||||||
Success: false,
|
|
||||||
Error: &common.APIError{
|
|
||||||
Code: code,
|
Code: code,
|
||||||
Message: message,
|
Message: message,
|
||||||
Details: details,
|
Details: details,
|
||||||
Detail: fmt.Sprintf("%v", details),
|
Detail: fmt.Sprintf("%v", details),
|
||||||
},
|
}
|
||||||
})
|
if asErr, ok := details.(error); ok {
|
||||||
if err != nil {
|
var sqlErr *common.SQLError
|
||||||
|
if errors.As(asErr, &sqlErr) {
|
||||||
|
apiErr.SQL = sqlErr.SQL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.SetHeader("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
if err := w.WriteJSON(common.Response{Success: false, Error: apiErr}); err != nil {
|
||||||
logger.Error("Error sending response: %v", err)
|
logger.Error("Error sending response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1709,6 +1995,51 @@ func toSnakeCase(s string) string {
|
|||||||
return strings.ToLower(result.String())
|
return strings.ToLower(result.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setRowNumbersOnRecords sets the RowNumber field on each record if it exists
|
||||||
|
// The row number is calculated as offset + index + 1 (1-based)
|
||||||
|
func (h *Handler) setRowNumbersOnRecords(records interface{}, offset int) {
|
||||||
|
// Get the reflect value of the records
|
||||||
|
recordsValue := reflect.ValueOf(records)
|
||||||
|
if recordsValue.Kind() == reflect.Ptr {
|
||||||
|
recordsValue = recordsValue.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a slice
|
||||||
|
if recordsValue.Kind() != reflect.Slice {
|
||||||
|
logger.Debug("setRowNumbersOnRecords: records is not a slice, skipping")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate through each record
|
||||||
|
for i := 0; i < recordsValue.Len(); i++ {
|
||||||
|
record := recordsValue.Index(i)
|
||||||
|
|
||||||
|
// Dereference if it's a pointer
|
||||||
|
if record.Kind() == reflect.Ptr {
|
||||||
|
if record.IsNil() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
record = record.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure it's a struct
|
||||||
|
if record.Kind() != reflect.Struct {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find and set the RowNumber field
|
||||||
|
rowNumberField := record.FieldByName("RowNumber")
|
||||||
|
if rowNumberField.IsValid() && rowNumberField.CanSet() {
|
||||||
|
// Check if the field is of type int64
|
||||||
|
if rowNumberField.Kind() == reflect.Int64 {
|
||||||
|
rowNum := int64(offset + i + 1)
|
||||||
|
rowNumberField.SetInt(rowNum)
|
||||||
|
logger.Debug("Set RowNumber=%d for record index %d", rowNum, i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// HandleOpenAPI generates and returns the OpenAPI specification
|
// HandleOpenAPI generates and returns the OpenAPI specification
|
||||||
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
func (h *Handler) HandleOpenAPI(w common.ResponseWriter, r common.Request) {
|
||||||
if h.openAPIGenerator == nil {
|
if h.openAPIGenerator == nil {
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
type HookType string
|
type HookType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
// Use this for auth checks that need model rules and user context simultaneously.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
// Read operation hooks
|
// Read operation hooks
|
||||||
BeforeRead HookType = "before_read"
|
BeforeRead HookType = "before_read"
|
||||||
AfterRead HookType = "after_read"
|
AfterRead HookType = "after_read"
|
||||||
@@ -43,6 +47,9 @@ type HookContext struct {
|
|||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
Request common.Request
|
Request common.Request
|
||||||
|
|
||||||
|
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||||
|
Operation string
|
||||||
|
|
||||||
// Operation-specific fields
|
// Operation-specific fields
|
||||||
ID string
|
ID string
|
||||||
Data interface{} // For create/update operations
|
Data interface{} // For create/update operations
|
||||||
|
|||||||
+130
-20
@@ -70,17 +70,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
entityWithIDPath := buildRoutePath(schema, entity) + "/{id}"
|
||||||
|
|
||||||
// Create handler functions for this specific entity
|
// Create handler functions for this specific entity
|
||||||
postEntityHandler := createMuxHandler(handler, schema, entity, "")
|
var postEntityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||||
postEntityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
var postEntityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||||
getEntityHandler := createMuxGetHandler(handler, schema, entity, "")
|
var getEntityHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"POST", "OPTIONS"})
|
||||||
|
|
||||||
// Apply authentication middleware if provided
|
// Apply authentication middleware if provided
|
||||||
if authMiddleware != nil {
|
if authMiddleware != nil {
|
||||||
postEntityHandler = authMiddleware(postEntityHandler).(http.HandlerFunc)
|
postEntityHandler = authMiddleware(postEntityHandler)
|
||||||
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler).(http.HandlerFunc)
|
postEntityWithIDHandler = authMiddleware(postEntityWithIDHandler)
|
||||||
getEntityHandler = authMiddleware(getEntityHandler).(http.HandlerFunc)
|
getEntityHandler = authMiddleware(getEntityHandler)
|
||||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -216,9 +216,34 @@ type BunRouterHandler interface {
|
|||||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||||
|
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||||
|
if authMiddleware == nil {
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
// Create an http.Handler that calls the bunrouter handler
|
||||||
|
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Replace the embedded *http.Request with the middleware-enriched one
|
||||||
|
// so that auth context (user ID, etc.) is visible to the handler.
|
||||||
|
enrichedReq := req
|
||||||
|
enrichedReq.Request = r
|
||||||
|
_ = handler(w, enrichedReq)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wrap with auth middleware and execute
|
||||||
|
wrappedHandler := authMiddleware(httpHandler)
|
||||||
|
wrappedHandler.ServeHTTP(w, req.Request)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the ResolveSpec API
|
||||||
// Accepts bunrouter.Router or bunrouter.Group
|
// Accepts bunrouter.Router or bunrouter.Group
|
||||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||||
|
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||||
|
|
||||||
// CORS config
|
// CORS config
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
@@ -256,7 +281,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
currentEntity := entity
|
currentEntity := entity
|
||||||
|
|
||||||
// POST route without ID
|
// POST route without ID
|
||||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -267,10 +292,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||||
|
|
||||||
// POST route with ID
|
// POST route with ID
|
||||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -282,10 +308,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
// GET route without ID
|
// GET route without ID
|
||||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -296,10 +323,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||||
|
|
||||||
// GET route with ID
|
// GET route with ID
|
||||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -311,9 +339,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
// OPTIONS route without ID (returns metadata)
|
// OPTIONS route without ID (returns metadata)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
@@ -330,6 +360,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// OPTIONS route with ID (returns metadata)
|
// OPTIONS route with ID (returns metadata)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewHTTPRequest(req.Request)
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
@@ -355,8 +386,8 @@ func ExampleWithBunRouter(bunDB *bun.DB) {
|
|||||||
// Create bunrouter
|
// Create bunrouter
|
||||||
bunRouter := bunrouter.New()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup ResolveSpec routes with bunrouter
|
// Setup ResolveSpec routes with bunrouter without authentication
|
||||||
SetupBunRouterRoutes(bunRouter, handler)
|
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
// http.ListenAndServe(":8080", bunRouter)
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
@@ -377,8 +408,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
|||||||
// Create bunrouter
|
// Create bunrouter
|
||||||
bunRouter := bunrouter.New()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup ResolveSpec routes
|
// Setup ResolveSpec routes without authentication
|
||||||
SetupBunRouterRoutes(bunRouter, handler)
|
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||||
|
|
||||||
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
// This gives you the full uptrace stack: bunrouter + Bun ORM
|
||||||
// http.ListenAndServe(":8080", bunRouter)
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
@@ -396,8 +427,87 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
|||||||
apiGroup := bunRouter.NewGroup("/api")
|
apiGroup := bunRouter.NewGroup("/api")
|
||||||
|
|
||||||
// Setup ResolveSpec routes on the group - routes will be under /api
|
// Setup ResolveSpec routes on the group - routes will be under /api
|
||||||
SetupBunRouterRoutes(apiGroup, handler)
|
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
// http.ListenAndServe(":8080", bunRouter)
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExampleWithGORMAndAuth shows how to use ResolveSpec with GORM and authentication
|
||||||
|
func ExampleWithGORMAndAuth(db *gorm.DB) {
|
||||||
|
// Create handler using GORM
|
||||||
|
_ = NewHandlerWithGORM(db)
|
||||||
|
|
||||||
|
// Create auth middleware
|
||||||
|
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
// secList := security.NewSecurityList(myProvider)
|
||||||
|
// authMiddleware := func(h http.Handler) http.Handler {
|
||||||
|
// return security.NewAuthHandler(secList, h)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Setup router with authentication
|
||||||
|
_ = mux.NewRouter()
|
||||||
|
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||||
|
|
||||||
|
// Register models
|
||||||
|
// handler.RegisterModel("public", "users", &User{})
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
// http.ListenAndServe(":8080", muxRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleWithBunAndAuth shows how to use ResolveSpec with Bun and authentication
|
||||||
|
func ExampleWithBunAndAuth(bunDB *bun.DB) {
|
||||||
|
// Create Bun adapter
|
||||||
|
dbAdapter := database.NewBunAdapter(bunDB)
|
||||||
|
|
||||||
|
// Create model registry
|
||||||
|
registry := modelregistry.NewModelRegistry()
|
||||||
|
// registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
// Create handler
|
||||||
|
_ = NewHandler(dbAdapter, registry)
|
||||||
|
|
||||||
|
// Create auth middleware
|
||||||
|
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
// secList := security.NewSecurityList(myProvider)
|
||||||
|
// authMiddleware := func(h http.Handler) http.Handler {
|
||||||
|
// return security.NewAuthHandler(secList, h)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Setup routes with authentication
|
||||||
|
_ = mux.NewRouter()
|
||||||
|
// SetupMuxRoutes(muxRouter, handler, authMiddleware)
|
||||||
|
|
||||||
|
// Start server
|
||||||
|
// http.ListenAndServe(":8080", muxRouter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExampleBunRouterWithBunDBAndAuth shows the full uptrace stack with authentication
|
||||||
|
func ExampleBunRouterWithBunDBAndAuth(bunDB *bun.DB) {
|
||||||
|
// Create Bun database adapter
|
||||||
|
dbAdapter := database.NewBunAdapter(bunDB)
|
||||||
|
|
||||||
|
// Create model registry
|
||||||
|
registry := modelregistry.NewModelRegistry()
|
||||||
|
// registry.RegisterModel("public.users", &User{})
|
||||||
|
|
||||||
|
// Create handler with Bun
|
||||||
|
_ = NewHandler(dbAdapter, registry)
|
||||||
|
|
||||||
|
// Create auth middleware
|
||||||
|
// import "github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
|
// secList := security.NewSecurityList(myProvider)
|
||||||
|
// authMiddleware := func(h http.Handler) http.Handler {
|
||||||
|
// return security.NewAuthHandler(secList, h)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Create bunrouter
|
||||||
|
_ = bunrouter.New()
|
||||||
|
|
||||||
|
// Setup ResolveSpec routes with authentication
|
||||||
|
// SetupBunRouterRoutes(bunRouter, handler, authMiddleware)
|
||||||
|
|
||||||
|
// This gives you the full uptrace stack: bunrouter + Bun ORM with authentication
|
||||||
|
// http.ListenAndServe(":8080", bunRouter)
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package resolvespec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
@@ -10,6 +11,17 @@ import (
|
|||||||
|
|
||||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeRead - Load security rules
|
// Hook 1: BeforeRead - Load security rules
|
||||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
secCtx := newSecurityContext(hookCtx)
|
secCtx := newSecurityContext(hookCtx)
|
||||||
@@ -34,6 +46,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
|||||||
return security.LogDataAccess(secCtx)
|
return security.LogDataAccess(secCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("Security hooks registered for resolvespec handler")
|
logger.Info("Security hooks registered for resolvespec handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -147,6 +147,7 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Available Hook Types**:
|
**Available Hook Types**:
|
||||||
|
* `BeforeHandle` — fires after model resolution, before operation dispatch (auth checks)
|
||||||
* `BeforeRead`, `AfterRead`
|
* `BeforeRead`, `AfterRead`
|
||||||
* `BeforeCreate`, `AfterCreate`
|
* `BeforeCreate`, `AfterCreate`
|
||||||
* `BeforeUpdate`, `AfterUpdate`
|
* `BeforeUpdate`, `AfterUpdate`
|
||||||
@@ -157,11 +158,13 @@ handler.Hooks.Register(restheadspec.BeforeCreate, func(ctx *restheadspec.HookCon
|
|||||||
* `Handler`: Access to handler, database, and registry
|
* `Handler`: Access to handler, database, and registry
|
||||||
* `Schema`, `Entity`, `TableName`: Request info
|
* `Schema`, `Entity`, `TableName`: Request info
|
||||||
* `Model`: The registered model type
|
* `Model`: The registered model type
|
||||||
|
* `Operation`: Current operation string (`"read"`, `"create"`, `"update"`, `"delete"`)
|
||||||
* `Options`: Parsed request options (filters, sorting, etc.)
|
* `Options`: Parsed request options (filters, sorting, etc.)
|
||||||
* `ID`: Record ID (for single-record operations)
|
* `ID`: Record ID (for single-record operations)
|
||||||
* `Data`: Request data (for create/update)
|
* `Data`: Request data (for create/update)
|
||||||
* `Result`: Operation result (for after hooks)
|
* `Result`: Operation result (for after hooks)
|
||||||
* `Writer`: Response writer (allows hooks to modify response)
|
* `Writer`: Response writer (allows hooks to modify response)
|
||||||
|
* `Abort`, `AbortMessage`, `AbortCode`: Set in hook to abort with an error response
|
||||||
|
|
||||||
## Cursor Pagination
|
## Cursor Pagination
|
||||||
|
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
modelColumns []string, // optional: for validation
|
modelColumns []string, // optional: for validation
|
||||||
expandJoins map[string]string, // optional: alias → JOIN SQL
|
expandJoins map[string]string, // optional: alias → JOIN SQL
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
|
// Separate schema prefix from bare table name
|
||||||
|
fullTableName := tableName
|
||||||
if strings.Contains(tableName, ".") {
|
if strings.Contains(tableName, ".") {
|
||||||
tableName = strings.SplitN(tableName, ".", 2)[1]
|
tableName = strings.SplitN(tableName, ".", 2)[1]
|
||||||
}
|
}
|
||||||
@@ -62,7 +64,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
// 4. Process each sort column
|
// 4. Process each sort column
|
||||||
// --------------------------------------------------------------------- //
|
// --------------------------------------------------------------------- //
|
||||||
for _, s := range sortItems {
|
for _, s := range sortItems {
|
||||||
col := strings.TrimSpace(s.Column)
|
col := strings.Trim(strings.TrimSpace(s.Column), "()")
|
||||||
if col == "" {
|
if col == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -91,7 +93,8 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle joins
|
// Handle joins
|
||||||
if isJoin && expandJoins != nil {
|
if isJoin {
|
||||||
|
if expandJoins != nil {
|
||||||
if joinClause, ok := expandJoins[prefix]; ok {
|
if joinClause, ok := expandJoins[prefix]; ok {
|
||||||
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
jSQL, cRef := rewriteJoin(joinClause, tableName, prefix)
|
||||||
joinSQL = jSQL
|
joinSQL = jSQL
|
||||||
@@ -99,6 +102,11 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
targetCol = prefix + "." + field
|
targetCol = prefix + "." + field
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if cursorCol == "" {
|
||||||
|
logger.Warn("Skipping cursor sort column %q: join alias %q not in expandJoins", col, prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Build inequality
|
// Build inequality
|
||||||
op := "<"
|
op := "<"
|
||||||
@@ -127,7 +135,7 @@ func (opts *ExtendedRequestOptions) GetCursorFilter(
|
|||||||
WHERE cursor_select.%s = %s
|
WHERE cursor_select.%s = %s
|
||||||
AND (%s)
|
AND (%s)
|
||||||
)`,
|
)`,
|
||||||
tableName,
|
fullTableName,
|
||||||
joinSQL,
|
joinSQL,
|
||||||
pkName,
|
pkName,
|
||||||
cursorID,
|
cursorID,
|
||||||
|
|||||||
@@ -187,9 +187,9 @@ func TestGetCursorFilter_WithSchemaPrefix(t *testing.T) {
|
|||||||
t.Fatalf("GetCursorFilter failed: %v", err)
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Should handle schema prefix properly
|
// Should include full schema-qualified name in FROM clause
|
||||||
if !strings.Contains(filter, "users") {
|
if !strings.Contains(filter, "public.users") {
|
||||||
t.Errorf("Filter should reference table name users, got: %s", filter)
|
t.Errorf("Filter FROM clause should use schema-qualified name public.users, got: %s", filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Logf("Generated cursor filter with schema: %s", filter)
|
t.Logf("Generated cursor filter with schema: %s", filter)
|
||||||
@@ -278,6 +278,47 @@ func TestCleanSortField(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetCursorFilter_LateralJoin(t *testing.T) {
|
||||||
|
lateralJoin := "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(account.rid_account) r\ninner join account a on a.id = r.id\n) fn on true"
|
||||||
|
|
||||||
|
opts := &ExtendedRequestOptions{
|
||||||
|
RequestOptions: common.RequestOptions{
|
||||||
|
Sort: []common.SortOption{
|
||||||
|
{Column: "fn.sortorder", Direction: "ASC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
opts.CursorForward = "8975"
|
||||||
|
|
||||||
|
tableName := "core.account"
|
||||||
|
pkName := "rid_account"
|
||||||
|
// modelColumns does not contain "sortorder" - it's a lateral join computed column
|
||||||
|
modelColumns := []string{"rid_account", "description", "pastelno"}
|
||||||
|
expandJoins := map[string]string{"fn": lateralJoin}
|
||||||
|
|
||||||
|
filter, err := opts.GetCursorFilter(tableName, pkName, modelColumns, expandJoins)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCursorFilter failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Generated lateral cursor filter: %s", filter)
|
||||||
|
|
||||||
|
// Should contain the rewritten lateral join inside the EXISTS subquery
|
||||||
|
if !strings.Contains(filter, "cursor_select_fn") {
|
||||||
|
t.Errorf("Filter should reference cursor_select_fn alias, got: %s", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should compare fn.sortorder values
|
||||||
|
if !strings.Contains(filter, "sortorder") {
|
||||||
|
t.Errorf("Filter should reference sortorder column, got: %s", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT contain empty comparison like "< "
|
||||||
|
if strings.Contains(filter, " < ") || strings.Contains(filter, " > ") {
|
||||||
|
t.Errorf("Filter should not contain empty comparison operators, got: %s", filter)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildPriorityChain(t *testing.T) {
|
func TestBuildPriorityChain(t *testing.T) {
|
||||||
clauses := []string{
|
clauses := []string{
|
||||||
"cursor_select.priority > posts.priority",
|
"cursor_select.priority > posts.priority",
|
||||||
|
|||||||
@@ -9,29 +9,29 @@ import (
|
|||||||
"github.com/bitechdev/ResolveSpec/pkg/common"
|
"github.com/bitechdev/ResolveSpec/pkg/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Test that normalizeResultArray returns empty array when no records found without ID
|
// Test that normalizeResultArray returns empty object when no records found (single-record mode)
|
||||||
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
||||||
handler := &Handler{}
|
handler := &Handler{}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
input interface{}
|
input interface{}
|
||||||
shouldBeEmptyArr bool
|
shouldBeEmptyObj bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "nil should return empty array",
|
name: "nil should return empty object",
|
||||||
input: nil,
|
input: nil,
|
||||||
shouldBeEmptyArr: true,
|
shouldBeEmptyObj: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "empty slice should return empty array",
|
name: "empty slice should return empty object",
|
||||||
input: []*EmptyTestModel{},
|
input: []*EmptyTestModel{},
|
||||||
shouldBeEmptyArr: true,
|
shouldBeEmptyObj: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "single element should return the element",
|
name: "single element should return the element",
|
||||||
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
|
input: []*EmptyTestModel{{ID: 1, Name: "test"}},
|
||||||
shouldBeEmptyArr: false,
|
shouldBeEmptyObj: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple elements should return the slice",
|
name: "multiple elements should return the slice",
|
||||||
@@ -39,7 +39,7 @@ func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
|||||||
{ID: 1, Name: "test1"},
|
{ID: 1, Name: "test1"},
|
||||||
{ID: 2, Name: "test2"},
|
{ID: 2, Name: "test2"},
|
||||||
},
|
},
|
||||||
shouldBeEmptyArr: false,
|
shouldBeEmptyObj: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,25 +47,25 @@ func TestNormalizeResultArray_EmptyArrayWhenNoID(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result := handler.normalizeResultArray(tt.input)
|
result := handler.normalizeResultArray(tt.input)
|
||||||
|
|
||||||
// For cases that should return empty array
|
// For cases that should return empty object
|
||||||
if tt.shouldBeEmptyArr {
|
if tt.shouldBeEmptyObj {
|
||||||
emptyArr, ok := result.([]interface{})
|
emptyObj, ok := result.(map[string]interface{})
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Errorf("Expected empty array []interface{}{}, got %T: %v", result, result)
|
t.Errorf("Expected empty object map[string]interface{}{}, got %T: %v", result, result)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(emptyArr) != 0 {
|
if len(emptyObj) != 0 {
|
||||||
t.Errorf("Expected empty array with length 0, got length %d", len(emptyArr))
|
t.Errorf("Expected empty object with length 0, got length %d", len(emptyObj))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify it serializes to [] and not null
|
// Verify it serializes to {} and not null
|
||||||
jsonBytes, err := json.Marshal(result)
|
jsonBytes, err := json.Marshal(result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Failed to marshal result: %v", err)
|
t.Errorf("Failed to marshal result: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if string(jsonBytes) != "[]" {
|
if string(jsonBytes) != "{}" {
|
||||||
t.Errorf("Expected JSON '[]', got '%s'", string(jsonBytes))
|
t.Errorf("Expected JSON '{}', got '%s'", string(jsonBytes))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@@ -138,12 +138,12 @@ func TestSendResponseWithOptions_NoDataFoundHeader(t *testing.T) {
|
|||||||
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
t.Errorf("Expected X-No-Data-Found header to be 'true', got '%s'", mockWriter.headers["X-No-Data-Found"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check status code is 200
|
// Check status code is 200 even when no records found
|
||||||
if mockWriter.statusCode != 200 {
|
if mockWriter.statusCode != 200 {
|
||||||
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
|
t.Errorf("Expected status code 200, got %d", mockWriter.statusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the body is an empty array
|
// Verify the body is an empty array (list request, SingleRecordAsObject not set)
|
||||||
if mockWriter.body == nil {
|
if mockWriter.body == nil {
|
||||||
t.Error("Expected body to be set, got nil")
|
t.Error("Expected body to be set, got nil")
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
+244
-96
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
@@ -133,6 +134,41 @@ func (h *Handler) Handle(w common.ResponseWriter, r common.Request, params map[s
|
|||||||
// Add request-scoped data to context (including options)
|
// Add request-scoped data to context (including options)
|
||||||
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
ctx = WithRequestData(ctx, schema, entity, tableName, model, modelPtr, options)
|
||||||
|
|
||||||
|
// Derive operation for auth check
|
||||||
|
var operation string
|
||||||
|
switch method {
|
||||||
|
case "GET":
|
||||||
|
operation = "read"
|
||||||
|
case "POST":
|
||||||
|
operation = "create"
|
||||||
|
case "PUT", "PATCH":
|
||||||
|
operation = "update"
|
||||||
|
case "DELETE":
|
||||||
|
operation = "delete"
|
||||||
|
default:
|
||||||
|
operation = "read"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute BeforeHandle hook - auth check fires here, after model resolution
|
||||||
|
beforeCtx := &HookContext{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: h,
|
||||||
|
Schema: schema,
|
||||||
|
Entity: entity,
|
||||||
|
Model: model,
|
||||||
|
Writer: w,
|
||||||
|
Request: r,
|
||||||
|
Operation: operation,
|
||||||
|
}
|
||||||
|
if err := h.hooks.Execute(BeforeHandle, beforeCtx); err != nil {
|
||||||
|
code := http.StatusUnauthorized
|
||||||
|
if beforeCtx.AbortCode != 0 {
|
||||||
|
code = beforeCtx.AbortCode
|
||||||
|
}
|
||||||
|
h.sendError(w, code, "unauthorized", beforeCtx.AbortMessage, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
switch method {
|
switch method {
|
||||||
case "GET":
|
case "GET":
|
||||||
if id != "" {
|
if id != "" {
|
||||||
@@ -540,24 +576,63 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply custom SQL JOIN clauses
|
// Apply custom SQL JOIN clauses, skipping any whose alias is already provided by a
|
||||||
|
// preload LEFT JOIN (to prevent "table name specified more than once" errors).
|
||||||
if len(options.CustomSQLJoin) > 0 {
|
if len(options.CustomSQLJoin) > 0 {
|
||||||
for _, joinClause := range options.CustomSQLJoin {
|
preloadAliasSet := make(map[string]bool, len(options.Preload))
|
||||||
|
for i := range options.Preload {
|
||||||
|
if alias := common.RelationPathToBunAlias(options.Preload[i].Relation); alias != "" {
|
||||||
|
preloadAliasSet[alias] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, joinClause := range options.CustomSQLJoin {
|
||||||
|
if i < len(options.JoinAliases) && options.JoinAliases[i] != "" {
|
||||||
|
alias := strings.ToLower(options.JoinAliases[i])
|
||||||
|
if preloadAliasSet[alias] {
|
||||||
|
logger.Debug("Skipping custom SQL JOIN (alias '%s' already joined by preload): %s", alias, joinClause)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
|
logger.Debug("Applying custom SQL JOIN: %s", joinClause)
|
||||||
// Joins are already sanitized during parsing, so we can apply them directly
|
|
||||||
query = query.Join(joinClause)
|
query = query.Join(joinClause)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If ID is provided, filter by ID
|
// Handle FetchRowNumber before applying ID filter
|
||||||
if id != "" {
|
// This must happen before the query to get the row position, then filter by PK
|
||||||
|
var fetchedRowNumber *int64
|
||||||
|
var fetchRowNumberPKValue string
|
||||||
|
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
||||||
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
|
fetchRowNumberPKValue = *options.FetchRowNumber
|
||||||
|
|
||||||
|
logger.Debug("FetchRowNumber: Fetching row number for PK %s = %s", pkName, fetchRowNumberPKValue)
|
||||||
|
|
||||||
|
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, fetchRowNumberPKValue, options, model)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to fetch row number: %v", err)
|
||||||
|
h.sendError(w, http.StatusBadRequest, "fetch_rownumber_error", "Failed to fetch row number", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fetchedRowNumber = &rowNum
|
||||||
|
logger.Debug("FetchRowNumber: Row number %d for PK %s = %s", rowNum, pkName, fetchRowNumberPKValue)
|
||||||
|
|
||||||
|
// Now filter the main query to this specific primary key
|
||||||
|
tableAlias := reflection.ExtractTableNameOnly(tableName)
|
||||||
|
query = query.Where(fmt.Sprintf("%s.%s = ?", common.QuoteIdent(tableAlias), common.QuoteIdent(pkName)), fetchRowNumberPKValue)
|
||||||
|
} else if id != "" {
|
||||||
|
// If ID is provided (and not FetchRowNumber), filter by ID
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
pkName := reflection.GetPrimaryKeyName(model)
|
||||||
logger.Debug("Filtering by ID=%s: %s", pkName, id)
|
logger.Debug("Filtering by ID=%s: %s", pkName, id)
|
||||||
|
|
||||||
query = query.Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), id)
|
tableAlias := reflection.ExtractTableNameOnly(tableName)
|
||||||
|
query = query.Where(fmt.Sprintf("%s.%s = ?", common.QuoteIdent(tableAlias), common.QuoteIdent(pkName)), id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply sorting
|
// Apply sorting
|
||||||
|
tableAlias := reflection.ExtractTableNameOnly(tableName)
|
||||||
for _, sort := range options.Sort {
|
for _, sort := range options.Sort {
|
||||||
direction := "ASC"
|
direction := "ASC"
|
||||||
if strings.EqualFold(sort.Direction, "desc") {
|
if strings.EqualFold(sort.Direction, "desc") {
|
||||||
@@ -569,9 +644,12 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
if strings.HasPrefix(sort.Column, "(") && strings.HasSuffix(sort.Column, ")") {
|
||||||
// For expressions, pass as raw SQL to prevent auto-quoting
|
// For expressions, pass as raw SQL to prevent auto-quoting
|
||||||
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
|
} else if strings.Contains(sort.Column, ".") {
|
||||||
|
// Already qualified (e.g. alias.column) - pass as raw expression to preserve the dot
|
||||||
|
query = query.OrderExpr(fmt.Sprintf("%s %s", sort.Column, direction))
|
||||||
} else {
|
} else {
|
||||||
// Regular column - let Bun handle quoting
|
// Unqualified column - prefix with main table alias to avoid ambiguity on JOINs
|
||||||
query = query.Order(fmt.Sprintf("%s %s", sort.Column, direction))
|
query = query.OrderExpr(fmt.Sprintf("%s.%s %s", common.QuoteIdent(tableAlias), common.QuoteIdent(sort.Column), direction))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -666,12 +744,19 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
// Extract model columns for validation using the generic database function
|
// Extract model columns for validation using the generic database function
|
||||||
modelColumns := reflection.GetModelColumns(model)
|
modelColumns := reflection.GetModelColumns(model)
|
||||||
|
|
||||||
// Build expand joins map (if needed in future)
|
// Build expand joins map: custom SQL joins are available in cursor subquery
|
||||||
var expandJoins map[string]string
|
expandJoins := make(map[string]string)
|
||||||
if len(options.Expand) > 0 {
|
for _, joinClause := range options.CustomSQLJoin {
|
||||||
expandJoins = make(map[string]string)
|
alias := extractJoinAlias(joinClause)
|
||||||
// TODO: Build actual JOIN SQL for each expand relation
|
if alias != "" {
|
||||||
// For now, pass empty map as joins are handled via Preload
|
expandJoins[alias] = joinClause
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO: also add Expand relation JOINs when those are built as SQL rather than Preload
|
||||||
|
|
||||||
|
// Default sort to primary key when none provided
|
||||||
|
if len(options.Sort) == 0 {
|
||||||
|
options.Sort = []common.SortOption{{Column: pkName, Direction: "ASC"}}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cursor filter SQL
|
// Get cursor filter SQL
|
||||||
@@ -730,7 +815,14 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set row numbers on each record if the model has a RowNumber field
|
// Set row numbers on each record if the model has a RowNumber field
|
||||||
|
// If FetchRowNumber was used, set the fetched row number instead of offset-based
|
||||||
|
if fetchedRowNumber != nil {
|
||||||
|
// FetchRowNumber: set the actual row position on the record
|
||||||
|
logger.Debug("FetchRowNumber: Setting row number %d on record", *fetchedRowNumber)
|
||||||
|
h.setRowNumbersOnRecords(modelPtr, int(*fetchedRowNumber-1)) // -1 because setRowNumbersOnRecords adds 1
|
||||||
|
} else {
|
||||||
h.setRowNumbersOnRecords(modelPtr, offset)
|
h.setRowNumbersOnRecords(modelPtr, offset)
|
||||||
|
}
|
||||||
|
|
||||||
metadata := &common.Metadata{
|
metadata := &common.Metadata{
|
||||||
Total: int64(total),
|
Total: int64(total),
|
||||||
@@ -740,21 +832,10 @@ func (h *Handler) handleRead(ctx context.Context, w common.ResponseWriter, id st
|
|||||||
Offset: offset,
|
Offset: offset,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch row number for a specific record if requested
|
// If FetchRowNumber was used, also set it in metadata
|
||||||
if options.FetchRowNumber != nil && *options.FetchRowNumber != "" {
|
if fetchedRowNumber != nil {
|
||||||
pkName := reflection.GetPrimaryKeyName(model)
|
metadata.RowNumber = fetchedRowNumber
|
||||||
pkValue := *options.FetchRowNumber
|
logger.Debug("FetchRowNumber: Row number %d set in metadata", *fetchedRowNumber)
|
||||||
|
|
||||||
logger.Debug("Fetching row number for specific PK %s = %s", pkName, pkValue)
|
|
||||||
|
|
||||||
rowNum, err := h.FetchRowNumber(ctx, tableName, pkName, pkValue, options, model)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to fetch row number: %v", err)
|
|
||||||
// Don't fail the entire request, just log the warning
|
|
||||||
} else {
|
|
||||||
metadata.RowNumber = &rowNum
|
|
||||||
logger.Debug("Row number for PK %s: %d", pkValue, rowNum)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute AfterRead hooks
|
// Execute AfterRead hooks
|
||||||
@@ -1137,8 +1218,8 @@ func (h *Handler) handleCreate(ctx context.Context, w common.ResponseWriter, dat
|
|||||||
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
if provider, ok := modelValue.(common.TableNameProvider); !ok || provider.TableName() == "" {
|
||||||
query = query.Table(tableName)
|
query = query.Table(tableName)
|
||||||
}
|
}
|
||||||
|
fields := reflection.GetSQLModelColumns(model)
|
||||||
query = query.Returning("*")
|
query = query.Returning(fields...)
|
||||||
|
|
||||||
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
// Execute BeforeScan hooks - pass query chain so hooks can modify it
|
||||||
itemHookCtx := &HookContext{
|
itemHookCtx := &HookContext{
|
||||||
@@ -1286,7 +1367,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
|
|
||||||
// First, read the existing record from the database
|
// First, read the existing record from the database
|
||||||
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
existingRecord := reflect.New(reflection.GetPointerElement(reflect.TypeOf(model))).Interface()
|
||||||
selectQuery := tx.NewSelect().Model(existingRecord).Column("*").Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
selectQuery := tx.NewSelect().Model(existingRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return fmt.Errorf("record not found with ID: %v", targetID)
|
return fmt.Errorf("record not found with ID: %v", targetID)
|
||||||
@@ -1399,18 +1480,7 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the updated record to return the new values
|
_ = result
|
||||||
modelValue := reflect.New(reflect.TypeOf(model)).Interface()
|
|
||||||
selectQuery = tx.NewSelect().Model(modelValue).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
|
||||||
if err := selectQuery.ScanModel(ctx); err != nil {
|
|
||||||
return fmt.Errorf("failed to fetch updated record: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedRecord = modelValue
|
|
||||||
|
|
||||||
// Store result for hooks
|
|
||||||
hookCtx.Result = updatedRecord
|
|
||||||
_ = result // Keep result variable for potential future use
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1420,6 +1490,16 @@ func (h *Handler) handleUpdate(ctx context.Context, w common.ResponseWriter, id
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Fetch the updated record after the transaction commits to capture any trigger changes
|
||||||
|
fetchedRecord := reflect.New(reflect.TypeOf(model)).Interface()
|
||||||
|
selectQuery := h.db.NewSelect().Model(fetchedRecord).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(pkName)), targetID)
|
||||||
|
if err := selectQuery.ScanModel(ctx); err != nil {
|
||||||
|
logger.Error("Failed to fetch updated record: %v", err)
|
||||||
|
h.sendError(w, http.StatusInternalServerError, "fetch_error", "Failed to fetch updated record", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updatedRecord = fetchedRecord
|
||||||
|
|
||||||
// Merge the updated record with the original request data
|
// Merge the updated record with the original request data
|
||||||
// This preserves extra keys from the request and updates values from the database
|
// This preserves extra keys from the request and updates values from the database
|
||||||
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
mergedData := h.mergeRecordWithRequest(updatedRecord, dataMap)
|
||||||
@@ -1480,8 +1560,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %s: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %s: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -1554,8 +1634,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -1612,8 +1692,8 @@ func (h *Handler) handleDelete(ctx context.Context, w common.ResponseWriter, id
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
if err := h.hooks.Execute(BeforeDelete, hookCtx); err != nil {
|
||||||
logger.Warn("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
logger.Error("BeforeDelete hook failed for ID %v: %v", itemID, err)
|
||||||
continue
|
return fmt.Errorf("delete not allowed for ID %v: %w", itemID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
query := tx.NewDelete().Table(tableName).Where(fmt.Sprintf("%s = ?", common.QuoteIdent(reflection.GetPrimaryKeyName(model))), itemID)
|
||||||
@@ -2058,11 +2138,12 @@ func (h *Handler) qualifyColumnName(columnName, fullTableName string) string {
|
|||||||
|
|
||||||
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery {
|
func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOption, tableName string, needsCast bool, logicOp string) common.SelectQuery {
|
||||||
// Qualify the column name with table name if not already qualified
|
// Qualify the column name with table name if not already qualified
|
||||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
rawQualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
qualifiedColumn := rawQualifiedColumn
|
||||||
|
|
||||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||||
if needsCast {
|
if needsCast {
|
||||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to apply the correct Where method based on logic operator
|
// Helper function to apply the correct Where method based on logic operator
|
||||||
@@ -2087,13 +2168,17 @@ func (h *Handler) applyFilter(query common.SelectQuery, filter common.FilterOpti
|
|||||||
case "lte", "less_than_equals", "le":
|
case "lte", "less_than_equals", "le":
|
||||||
return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value)
|
return applyWhere(fmt.Sprintf("%s <= ?", qualifiedColumn), filter.Value)
|
||||||
case "like":
|
case "like":
|
||||||
return applyWhere(fmt.Sprintf("%s LIKE ?", qualifiedColumn), filter.Value)
|
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||||
|
return applyWhere(fmt.Sprintf("CAST(%s AS TEXT) LIKE ?", rawQualifiedColumn), filter.Value)
|
||||||
case "ilike":
|
case "ilike":
|
||||||
// Use ILIKE for case-insensitive search (PostgreSQL)
|
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||||
// Column is already cast to TEXT if needed
|
return applyWhere(fmt.Sprintf("CAST(%s AS TEXT) ILIKE ?", rawQualifiedColumn), filter.Value)
|
||||||
return applyWhere(fmt.Sprintf("%s ILIKE ?", qualifiedColumn), filter.Value)
|
|
||||||
case "in":
|
case "in":
|
||||||
return applyWhere(fmt.Sprintf("%s IN (?)", qualifiedColumn), filter.Value)
|
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||||
|
if cond == "" {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return applyWhere(cond, inArgs...)
|
||||||
case "between":
|
case "between":
|
||||||
// Handle between operator - exclusive (> val1 AND < val2)
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
@@ -2139,11 +2224,16 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
|
|||||||
|
|
||||||
for i, filter := range filters {
|
for i, filter := range filters {
|
||||||
// Qualify the column name with table name if not already qualified
|
// Qualify the column name with table name if not already qualified
|
||||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
rawQualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
qualifiedColumn := rawQualifiedColumn
|
||||||
|
|
||||||
|
op := strings.ToLower(filter.Operator)
|
||||||
|
if op == "like" || op == "ilike" {
|
||||||
|
// Always cast to TEXT for LIKE/ILIKE to support date/time/timestamp columns
|
||||||
|
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||||
|
} else if castInfo[i].NeedsCast {
|
||||||
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
// Apply casting to text if needed for non-numeric columns or non-numeric values
|
||||||
if castInfo[i].NeedsCast {
|
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", rawQualifiedColumn)
|
||||||
qualifiedColumn = fmt.Sprintf("CAST(%s AS TEXT)", qualifiedColumn)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build the condition based on operator
|
// Build the condition based on operator
|
||||||
@@ -2169,24 +2259,25 @@ func (h *Handler) applyOrFilterGroup(query common.SelectQuery, filters []*common
|
|||||||
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
// buildFilterCondition builds a single filter condition and returns the condition string and args
|
||||||
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
func (h *Handler) buildFilterCondition(qualifiedColumn string, filter *common.FilterOption, tableName string) (filterStr string, filterInterface []interface{}) {
|
||||||
switch strings.ToLower(filter.Operator) {
|
switch strings.ToLower(filter.Operator) {
|
||||||
case "eq", "equals":
|
case "eq", "equals", "=":
|
||||||
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s = ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "neq", "not_equals", "ne":
|
case "neq", "not_equals", "ne", "!=", "<>":
|
||||||
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s != ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "gt", "greater_than":
|
case "gt", "greater_than", ">":
|
||||||
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s > ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "gte", "greater_than_equals", "ge":
|
case "gte", "greater_than_equals", "ge", ">=":
|
||||||
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s >= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "lt", "less_than":
|
case "lt", "less_than", "<":
|
||||||
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s < ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "lte", "less_than_equals", "le":
|
case "lte", "less_than_equals", "le", "<=":
|
||||||
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s <= ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "like":
|
case "like":
|
||||||
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s LIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "ilike":
|
case "ilike":
|
||||||
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
return fmt.Sprintf("%s ILIKE ?", qualifiedColumn), []interface{}{filter.Value}
|
||||||
case "in":
|
case "in":
|
||||||
return fmt.Sprintf("%s IN (?)", qualifiedColumn), []interface{}{filter.Value}
|
cond, inArgs := common.BuildInCondition(qualifiedColumn, filter.Value)
|
||||||
|
return cond, inArgs
|
||||||
case "between":
|
case "between":
|
||||||
// Handle between operator - exclusive (> val1 AND < val2)
|
// Handle between operator - exclusive (> val1 AND < val2)
|
||||||
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
if values, ok := filter.Value.([]interface{}); ok && len(values) == 2 {
|
||||||
@@ -2417,14 +2508,12 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
w.SetHeader("X-No-Data-Found", "true")
|
w.SetHeader("X-No-Data-Found", "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
|
|
||||||
// Normalize single-record arrays to objects if requested
|
// Normalize single-record arrays to objects if requested
|
||||||
if options != nil && options.SingleRecordAsObject {
|
if options != nil && options.SingleRecordAsObject {
|
||||||
data = h.normalizeResultArray(data)
|
data = h.normalizeResultArray(data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return data as-is without wrapping in common.Response
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
if err := w.WriteJSON(data); err != nil {
|
if err := w.WriteJSON(data); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
@@ -2435,7 +2524,7 @@ func (h *Handler) sendResponseWithOptions(w common.ResponseWriter, data interfac
|
|||||||
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
// Returns the single element if data is a slice/array with exactly one element, otherwise returns data unchanged
|
||||||
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
||||||
if data == nil {
|
if data == nil {
|
||||||
return []interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use reflection to check if data is a slice or array
|
// Use reflection to check if data is a slice or array
|
||||||
@@ -2450,15 +2539,15 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
// Return the single element
|
// Return the single element
|
||||||
return dataValue.Index(0).Interface()
|
return dataValue.Index(0).Interface()
|
||||||
} else if dataValue.Len() == 0 {
|
} else if dataValue.Len() == 0 {
|
||||||
// Keep empty array as empty array, don't convert to empty object
|
// Single-record request with no result → empty object
|
||||||
return []interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if dataValue.Kind() == reflect.String {
|
if dataValue.Kind() == reflect.String {
|
||||||
str := dataValue.String()
|
str := dataValue.String()
|
||||||
if str == "" || str == "null" {
|
if str == "" || str == "null" {
|
||||||
return []interface{}{}
|
return map[string]interface{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -2467,9 +2556,6 @@ func (h *Handler) normalizeResultArray(data interface{}) interface{} {
|
|||||||
|
|
||||||
// sendFormattedResponse sends response with formatting options
|
// sendFormattedResponse sends response with formatting options
|
||||||
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{}, metadata *common.Metadata, options ExtendedRequestOptions) {
|
||||||
// Normalize single-record arrays to objects if requested
|
|
||||||
httpStatus := http.StatusOK
|
|
||||||
|
|
||||||
// Handle nil data - convert to empty array
|
// Handle nil data - convert to empty array
|
||||||
if data == nil {
|
if data == nil {
|
||||||
data = []interface{}{}
|
data = []interface{}{}
|
||||||
@@ -2506,7 +2592,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
switch options.ResponseFormat {
|
switch options.ResponseFormat {
|
||||||
case "simple":
|
case "simple":
|
||||||
// Simple format: just return the data array
|
// Simple format: just return the data array
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := w.WriteJSON(data); err != nil {
|
if err := w.WriteJSON(data); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2518,7 +2604,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
if metadata != nil {
|
if metadata != nil {
|
||||||
response["count"] = metadata.Total
|
response["count"] = metadata.Total
|
||||||
}
|
}
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2529,7 +2615,7 @@ func (h *Handler) sendFormattedResponse(w common.ResponseWriter, data interface{
|
|||||||
Data: data,
|
Data: data,
|
||||||
Metadata: metadata,
|
Metadata: metadata,
|
||||||
}
|
}
|
||||||
w.WriteHeader(httpStatus)
|
w.WriteHeader(http.StatusOK)
|
||||||
if err := w.WriteJSON(response); err != nil {
|
if err := w.WriteJSON(response); err != nil {
|
||||||
logger.Error("Failed to write JSON response: %v", err)
|
logger.Error("Failed to write JSON response: %v", err)
|
||||||
}
|
}
|
||||||
@@ -2559,6 +2645,12 @@ func (h *Handler) sendError(w common.ResponseWriter, statusCode int, code, messa
|
|||||||
"_error": errorMsg,
|
"_error": errorMsg,
|
||||||
"_retval": 1,
|
"_retval": 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sqlErr *common.SQLError
|
||||||
|
if errors.As(err, &sqlErr) {
|
||||||
|
response["_sql"] = sqlErr.SQL
|
||||||
|
}
|
||||||
|
|
||||||
w.SetHeader("Content-Type", "application/json")
|
w.SetHeader("Content-Type", "application/json")
|
||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
if jsonErr := w.WriteJSON(response); jsonErr != nil {
|
if jsonErr := w.WriteJSON(response); jsonErr != nil {
|
||||||
@@ -2602,21 +2694,8 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
|||||||
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
sortSQL = fmt.Sprintf("%s.%s ASC", tableName, pkName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build WHERE clauses from filters
|
// Build WHERE clause from filters with proper OR grouping
|
||||||
whereClauses := make([]string, 0)
|
whereSQL := h.buildWhereClauseWithORGrouping(options.Filters, tableName)
|
||||||
for i := range options.Filters {
|
|
||||||
filter := &options.Filters[i]
|
|
||||||
whereClause := h.buildFilterSQL(filter, tableName)
|
|
||||||
if whereClause != "" {
|
|
||||||
whereClauses = append(whereClauses, fmt.Sprintf("(%s)", whereClause))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine WHERE clauses
|
|
||||||
whereSQL := ""
|
|
||||||
if len(whereClauses) > 0 {
|
|
||||||
whereSQL = "WHERE " + strings.Join(whereClauses, " AND ")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add custom SQL WHERE if provided
|
// Add custom SQL WHERE if provided
|
||||||
if options.CustomSQLWhere != "" {
|
if options.CustomSQLWhere != "" {
|
||||||
@@ -2664,19 +2743,86 @@ func (h *Handler) FetchRowNumber(ctx context.Context, tableName string, pkName s
|
|||||||
var result []struct {
|
var result []struct {
|
||||||
RN int64 `bun:"rn"`
|
RN int64 `bun:"rn"`
|
||||||
}
|
}
|
||||||
|
logger.Debug("[FetchRowNumber] BEFORE Query call - about to execute raw query")
|
||||||
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
err := h.db.Query(ctx, &result, queryStr, pkValue)
|
||||||
|
logger.Debug("[FetchRowNumber] AFTER Query call - query completed with %d results, err: %v", len(result), err)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
return 0, fmt.Errorf("failed to fetch row number: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(result) == 0 {
|
if len(result) == 0 {
|
||||||
return 0, fmt.Errorf("no row found for primary key %s", pkValue)
|
whereInfo := "none"
|
||||||
|
if whereSQL != "" {
|
||||||
|
whereInfo = whereSQL
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("no row found for primary key %s=%s with active filters: %s", pkName, pkValue, whereInfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result[0].RN, nil
|
return result[0].RN, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildFilterSQL converts a filter to SQL WHERE clause string
|
// buildFilterSQL converts a filter to SQL WHERE clause string
|
||||||
|
// buildWhereClauseWithORGrouping builds a WHERE clause from filters with proper OR grouping
|
||||||
|
// Groups consecutive OR filters together to ensure proper SQL precedence
|
||||||
|
// Example: [A, B(OR), C(OR), D(AND)] => WHERE (A OR B OR C) AND D
|
||||||
|
func (h *Handler) buildWhereClauseWithORGrouping(filters []common.FilterOption, tableName string) string {
|
||||||
|
if len(filters) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var groups []string
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
for i < len(filters) {
|
||||||
|
// Check if this starts an OR group (next filter has OR logic)
|
||||||
|
startORGroup := i+1 < len(filters) && strings.EqualFold(filters[i+1].LogicOperator, "OR")
|
||||||
|
|
||||||
|
if startORGroup {
|
||||||
|
// Collect all consecutive filters that are OR'd together
|
||||||
|
orGroup := []string{}
|
||||||
|
|
||||||
|
// Add current filter
|
||||||
|
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||||
|
if filterSQL != "" {
|
||||||
|
orGroup = append(orGroup, filterSQL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect remaining OR filters
|
||||||
|
j := i + 1
|
||||||
|
for j < len(filters) && strings.EqualFold(filters[j].LogicOperator, "OR") {
|
||||||
|
filterSQL := h.buildFilterSQL(&filters[j], tableName)
|
||||||
|
if filterSQL != "" {
|
||||||
|
orGroup = append(orGroup, filterSQL)
|
||||||
|
}
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Group OR filters with parentheses
|
||||||
|
if len(orGroup) > 0 {
|
||||||
|
if len(orGroup) == 1 {
|
||||||
|
groups = append(groups, orGroup[0])
|
||||||
|
} else {
|
||||||
|
groups = append(groups, "("+strings.Join(orGroup, " OR ")+")")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i = j
|
||||||
|
} else {
|
||||||
|
// Single filter with AND logic (or first filter)
|
||||||
|
filterSQL := h.buildFilterSQL(&filters[i], tableName)
|
||||||
|
if filterSQL != "" {
|
||||||
|
groups = append(groups, filterSQL)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(groups) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return "WHERE " + strings.Join(groups, " AND ")
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
func (h *Handler) buildFilterSQL(filter *common.FilterOption, tableName string) string {
|
||||||
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
qualifiedColumn := h.qualifyColumnName(filter.Column, tableName)
|
||||||
|
|
||||||
@@ -2767,6 +2913,8 @@ func (h *Handler) filterExtendedOptions(validator *common.ColumnValidator, optio
|
|||||||
|
|
||||||
// Filter base RequestOptions
|
// Filter base RequestOptions
|
||||||
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
filtered.RequestOptions = validator.FilterRequestOptions(options.RequestOptions)
|
||||||
|
// Restore JoinAliases cleared by FilterRequestOptions — still needed for SanitizeWhereClause
|
||||||
|
filtered.RequestOptions.JoinAliases = options.JoinAliases
|
||||||
|
|
||||||
// Filter SearchColumns
|
// Filter SearchColumns
|
||||||
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
filtered.SearchColumns = validator.FilterValidColumns(options.SearchColumns)
|
||||||
|
|||||||
+215
-49
@@ -5,6 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -63,7 +65,10 @@ type ExpandOption struct {
|
|||||||
// decodeHeaderValue decodes base64 encoded header values
|
// decodeHeaderValue decodes base64 encoded header values
|
||||||
// Supports ZIP_ and __ prefixes for base64 encoding
|
// Supports ZIP_ and __ prefixes for base64 encoding
|
||||||
func decodeHeaderValue(value string) string {
|
func decodeHeaderValue(value string) string {
|
||||||
str, _ := DecodeParam(value)
|
str, err := DecodeParam(value)
|
||||||
|
if err != nil {
|
||||||
|
return value
|
||||||
|
}
|
||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,9 +141,21 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
combinedParams[strings.ToLower(key)] = value
|
combinedParams[strings.ToLower(key)] = value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sortedKeys := make([]string, 0, len(combinedParams))
|
||||||
|
for key := range combinedParams {
|
||||||
|
sortedKeys = append(sortedKeys, key)
|
||||||
|
}
|
||||||
|
sort.Slice(sortedKeys, func(i, j int) bool {
|
||||||
|
if sortedKeys[i] != sortedKeys[j] {
|
||||||
|
return sortedKeys[i] < sortedKeys[j]
|
||||||
|
}
|
||||||
|
return combinedParams[sortedKeys[i]] < combinedParams[sortedKeys[j]]
|
||||||
|
})
|
||||||
|
|
||||||
// Process each parameter (from both headers and query params)
|
// Process each parameter (from both headers and query params)
|
||||||
// Note: keys are already normalized to lowercase in combinedParams
|
// Note: keys are already normalized to lowercase in combinedParams
|
||||||
for key, value := range combinedParams {
|
for _, key := range sortedKeys {
|
||||||
|
value := combinedParams[key]
|
||||||
// Decode value if it's base64 encoded
|
// Decode value if it's base64 encoded
|
||||||
decodedValue := decodeHeaderValue(value)
|
decodedValue := decodeHeaderValue(value)
|
||||||
|
|
||||||
@@ -274,9 +291,11 @@ func (h *Handler) parseOptionsFromHeaders(r common.Request, model interface{}) E
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve relation names (convert table names to field names) if model is provided
|
// Resolve relation names (convert table names/prefixes to actual model field names) if model is provided.
|
||||||
// Skip resolution if X-Files header was provided, as XFiles uses Prefix which already contains the correct field names
|
// This runs for both regular headers and X-Files, because XFile prefixes don't always match model
|
||||||
if model != nil && !options.XFilesPresent {
|
// field names (e.g., prefix "HUB" vs field "HUB_RID_HUB"). RelatedKey/ForeignKey are used to
|
||||||
|
// disambiguate when multiple fields point to the same related type.
|
||||||
|
if model != nil {
|
||||||
h.resolveRelationNamesInOptions(&options, model)
|
h.resolveRelationNamesInOptions(&options, model)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -499,6 +518,31 @@ func (h *Handler) parseExpand(options *ExtendedRequestOptions, value string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reMultiJoinBoundary finds the start of each individual JOIN clause within a string that
|
||||||
|
// may contain multiple consecutive JOIN clauses (e.g., "INNER JOIN ... LEFT OUTER JOIN ...").
|
||||||
|
var reMultiJoinBoundary = regexp.MustCompile(`(?i)(?:inner|left(?:\s+outer)?|right(?:\s+outer)?|full(?:\s+outer)?|cross)\s+join\b`)
|
||||||
|
|
||||||
|
// splitJoinClauses splits a SQL string that may contain multiple JOIN clauses into
|
||||||
|
// individual clauses. A plain pipe-separated segment may itself contain several JOINs;
|
||||||
|
// this function splits them so each gets its own alias entry.
|
||||||
|
func splitJoinClauses(joinStr string) []string {
|
||||||
|
indices := reMultiJoinBoundary.FindAllStringIndex(joinStr, -1)
|
||||||
|
if len(indices) <= 1 {
|
||||||
|
return []string{strings.TrimSpace(joinStr)}
|
||||||
|
}
|
||||||
|
parts := make([]string, 0, len(indices))
|
||||||
|
for i, idx := range indices {
|
||||||
|
end := len(joinStr)
|
||||||
|
if i+1 < len(indices) {
|
||||||
|
end = indices[i+1][0]
|
||||||
|
}
|
||||||
|
if part := strings.TrimSpace(joinStr[idx[0]:end]); part != "" {
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
// parseCustomSQLJoin parses x-custom-sql-join header
|
// parseCustomSQLJoin parses x-custom-sql-join header
|
||||||
// Format: Single JOIN clause or multiple JOIN clauses separated by |
|
// Format: Single JOIN clause or multiple JOIN clauses separated by |
|
||||||
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
|
// Example: "LEFT JOIN departments d ON d.id = employees.department_id"
|
||||||
@@ -531,17 +575,19 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract table alias from the JOIN clause
|
// Split into individual JOIN clauses so each clause gets its own alias entry.
|
||||||
alias := extractJoinAlias(sanitizedJoin)
|
// CustomSQLJoin and JoinAliases are kept parallel (one entry per individual clause).
|
||||||
if alias != "" {
|
for _, clause := range splitJoinClauses(sanitizedJoin) {
|
||||||
|
alias := extractJoinAlias(clause)
|
||||||
|
// Keep arrays parallel; use empty string when alias cannot be extracted.
|
||||||
options.JoinAliases = append(options.JoinAliases, alias)
|
options.JoinAliases = append(options.JoinAliases, alias)
|
||||||
// Also add to the embedded RequestOptions for validation
|
|
||||||
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
|
options.RequestOptions.JoinAliases = append(options.RequestOptions.JoinAliases, alias)
|
||||||
|
if alias != "" {
|
||||||
logger.Debug("Extracted join alias: %s", alias)
|
logger.Debug("Extracted join alias: %s", alias)
|
||||||
}
|
}
|
||||||
|
logger.Debug("Adding custom SQL join: %s", clause)
|
||||||
logger.Debug("Adding custom SQL join: %s", sanitizedJoin)
|
options.CustomSQLJoin = append(options.CustomSQLJoin, clause)
|
||||||
options.CustomSQLJoin = append(options.CustomSQLJoin, sanitizedJoin)
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -550,10 +596,8 @@ func (h *Handler) parseCustomSQLJoin(options *ExtendedRequestOptions, value stri
|
|||||||
// - "LEFT JOIN departments d ON ..." -> "d"
|
// - "LEFT JOIN departments d ON ..." -> "d"
|
||||||
// - "INNER JOIN users AS u ON ..." -> "u"
|
// - "INNER JOIN users AS u ON ..." -> "u"
|
||||||
// - "JOIN roles r ON ..." -> "r"
|
// - "JOIN roles r ON ..." -> "r"
|
||||||
|
// - "INNER JOIN LATERAL (...) fn ON true" -> "fn"
|
||||||
func extractJoinAlias(joinClause string) string {
|
func extractJoinAlias(joinClause string) string {
|
||||||
// Pattern: JOIN table_name [AS] alias ON ...
|
|
||||||
// We need to extract the alias (word before ON)
|
|
||||||
|
|
||||||
upperJoin := strings.ToUpper(joinClause)
|
upperJoin := strings.ToUpper(joinClause)
|
||||||
|
|
||||||
// Find the "JOIN" keyword position
|
// Find the "JOIN" keyword position
|
||||||
@@ -562,7 +606,20 @@ func extractJoinAlias(joinClause string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the "ON" keyword position
|
// Lateral joins: alias is the word after the closing ) and before ON
|
||||||
|
if strings.Contains(upperJoin, "LATERAL") {
|
||||||
|
lastClose := strings.LastIndex(joinClause, ")")
|
||||||
|
if lastClose != -1 {
|
||||||
|
words := strings.Fields(joinClause[lastClose+1:])
|
||||||
|
// words should be like ["fn", "on", "true"] or ["on", "true"]
|
||||||
|
if len(words) >= 1 && !strings.EqualFold(words[0], "on") {
|
||||||
|
return words[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular joins: find the "ON" keyword position (first occurrence)
|
||||||
onIdx := strings.Index(upperJoin, " ON ")
|
onIdx := strings.Index(upperJoin, " ON ")
|
||||||
if onIdx == -1 {
|
if onIdx == -1 {
|
||||||
return ""
|
return ""
|
||||||
@@ -863,8 +920,21 @@ func (h *Handler) resolveRelationNamesInOptions(options *ExtendedRequestOptions,
|
|||||||
|
|
||||||
// Resolve each part of the path
|
// Resolve each part of the path
|
||||||
currentModel := model
|
currentModel := model
|
||||||
for _, part := range parts {
|
for partIdx, part := range parts {
|
||||||
resolvedPart := h.resolveRelationName(currentModel, part)
|
isLast := partIdx == len(parts)-1
|
||||||
|
var resolvedPart string
|
||||||
|
if isLast {
|
||||||
|
// For the final part, use join-key-aware resolution to disambiguate when
|
||||||
|
// multiple fields point to the same type (e.g., HUB_RID_HUB vs HUB_RID_ASSIGNEDTO).
|
||||||
|
// RelatedKey = parent's local column linking to child; ForeignKey = local column linking to parent.
|
||||||
|
localKey := preload.RelatedKey
|
||||||
|
if localKey == "" {
|
||||||
|
localKey = preload.ForeignKey
|
||||||
|
}
|
||||||
|
resolvedPart = h.resolveRelationNameWithJoinKey(currentModel, part, localKey)
|
||||||
|
} else {
|
||||||
|
resolvedPart = h.resolveRelationName(currentModel, part)
|
||||||
|
}
|
||||||
resolvedParts = append(resolvedParts, resolvedPart)
|
resolvedParts = append(resolvedParts, resolvedPart)
|
||||||
|
|
||||||
// Try to get the model type for the next level
|
// Try to get the model type for the next level
|
||||||
@@ -980,6 +1050,101 @@ func (h *Handler) resolveRelationName(model interface{}, nameOrTable string) str
|
|||||||
return nameOrTable
|
return nameOrTable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resolveRelationNameWithJoinKey resolves a relation name like resolveRelationName, but when
|
||||||
|
// multiple fields point to the same related type, uses localKey to pick the one whose bun join
|
||||||
|
// tag starts with "join:localKey=". Falls back to resolveRelationName if no key match is found.
|
||||||
|
func (h *Handler) resolveRelationNameWithJoinKey(model interface{}, nameOrTable string, localKey string) string {
|
||||||
|
if localKey == "" {
|
||||||
|
return h.resolveRelationName(model, nameOrTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelType := reflect.TypeOf(model)
|
||||||
|
if modelType == nil {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
if modelType.Kind() == reflect.Ptr {
|
||||||
|
modelType = modelType.Elem()
|
||||||
|
}
|
||||||
|
if modelType == nil || modelType.Kind() != reflect.Struct {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's already a direct field name, return as-is (no ambiguity).
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
if modelType.Field(i).Name == nameOrTable {
|
||||||
|
return nameOrTable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedInput := strings.ToLower(strings.ReplaceAll(nameOrTable, "_", ""))
|
||||||
|
localKeyLower := strings.ToLower(localKey)
|
||||||
|
|
||||||
|
// Find all fields whose related type matches nameOrTable, then pick the one
|
||||||
|
// whose bun join tag local key matches localKey.
|
||||||
|
var fallbackField string
|
||||||
|
for i := 0; i < modelType.NumField(); i++ {
|
||||||
|
field := modelType.Field(i)
|
||||||
|
fieldType := field.Type
|
||||||
|
|
||||||
|
var targetType reflect.Type
|
||||||
|
if fieldType.Kind() == reflect.Slice {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
} else if fieldType.Kind() == reflect.Ptr {
|
||||||
|
targetType = fieldType.Elem()
|
||||||
|
}
|
||||||
|
if targetType != nil && targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
}
|
||||||
|
if targetType == nil || targetType.Kind() != reflect.Struct {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
normalizedTypeName := strings.ToLower(targetType.Name())
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "modelcore")
|
||||||
|
normalizedTypeName = strings.TrimPrefix(normalizedTypeName, "model")
|
||||||
|
if normalizedTypeName != normalizedInput {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type name matches; record as fallback.
|
||||||
|
if fallbackField == "" {
|
||||||
|
fallbackField = field.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check bun join tag: "join:localKey=foreignKey"
|
||||||
|
bunTag := field.Tag.Get("bun")
|
||||||
|
for _, tagPart := range strings.Split(bunTag, ",") {
|
||||||
|
tagPart = strings.TrimSpace(tagPart)
|
||||||
|
if !strings.HasPrefix(tagPart, "join:") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
joinSpec := strings.TrimPrefix(tagPart, "join:")
|
||||||
|
// joinSpec can be "col1=col2" or "col1=col2 col3=col4" (multi-col joins)
|
||||||
|
joinCols := strings.Fields(joinSpec)
|
||||||
|
if len(joinCols) == 0 {
|
||||||
|
joinCols = []string{joinSpec}
|
||||||
|
}
|
||||||
|
for _, joinCol := range joinCols {
|
||||||
|
eqIdx := strings.Index(joinCol, "=")
|
||||||
|
if eqIdx < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
joinLocalKey := strings.ToLower(joinCol[:eqIdx])
|
||||||
|
if joinLocalKey == localKeyLower {
|
||||||
|
logger.Debug("Resolved '%s' (localKey: %s) -> field '%s'", nameOrTable, localKey, field.Name)
|
||||||
|
return field.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if fallbackField != "" {
|
||||||
|
logger.Debug("No join key match for '%s' (localKey: %s), using first type match: '%s'", nameOrTable, localKey, fallbackField)
|
||||||
|
return fallbackField
|
||||||
|
}
|
||||||
|
return h.resolveRelationName(model, nameOrTable)
|
||||||
|
}
|
||||||
|
|
||||||
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
// addXFilesPreload converts an XFiles relation into a PreloadOption
|
||||||
// and recursively processes its children
|
// and recursively processes its children
|
||||||
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOptions, basePath string) {
|
||||||
@@ -1061,15 +1226,42 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer SqlJoins from XFiles to PreloadOption first, so aliases are available for WHERE sanitization
|
||||||
|
if len(xfile.SqlJoins) > 0 {
|
||||||
|
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
||||||
|
|
||||||
|
for _, joinClause := range xfile.SqlJoins {
|
||||||
|
// Sanitize the join clause
|
||||||
|
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
||||||
|
if sanitizedJoin == "" {
|
||||||
|
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
||||||
|
|
||||||
|
// Extract join alias for validation
|
||||||
|
alias := extractJoinAlias(sanitizedJoin)
|
||||||
|
if alias != "" {
|
||||||
|
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
||||||
|
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
||||||
|
}
|
||||||
|
|
||||||
// Add WHERE clause if SQL conditions specified
|
// Add WHERE clause if SQL conditions specified
|
||||||
|
// SqlJoins must be processed first so join aliases are known and not incorrectly replaced
|
||||||
whereConditions := make([]string, 0)
|
whereConditions := make([]string, 0)
|
||||||
if len(xfile.SqlAnd) > 0 {
|
if len(xfile.SqlAnd) > 0 {
|
||||||
// Process each SQL condition
|
var sqlAndOpts *common.RequestOptions
|
||||||
// Note: We don't add table prefixes here because they're only needed for JOINs
|
if len(preloadOpt.JoinAliases) > 0 {
|
||||||
// The handler will add prefixes later if SqlJoins are present
|
sqlAndOpts = &common.RequestOptions{JoinAliases: preloadOpt.JoinAliases}
|
||||||
|
}
|
||||||
for _, sqlCond := range xfile.SqlAnd {
|
for _, sqlCond := range xfile.SqlAnd {
|
||||||
// Sanitize the condition without adding prefixes
|
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName, sqlAndOpts)
|
||||||
sanitizedCond := common.SanitizeWhereClause(sqlCond, xfile.TableName)
|
|
||||||
if sanitizedCond != "" {
|
if sanitizedCond != "" {
|
||||||
whereConditions = append(whereConditions, sanitizedCond)
|
whereConditions = append(whereConditions, sanitizedCond)
|
||||||
}
|
}
|
||||||
@@ -1114,32 +1306,6 @@ func (h *Handler) addXFilesPreload(xfile *XFiles, options *ExtendedRequestOption
|
|||||||
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
logger.Debug("X-Files: Set foreign key for %s: %s", relationPath, xfile.ForeignKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Transfer SqlJoins from XFiles to PreloadOption
|
|
||||||
if len(xfile.SqlJoins) > 0 {
|
|
||||||
preloadOpt.SqlJoins = make([]string, 0, len(xfile.SqlJoins))
|
|
||||||
preloadOpt.JoinAliases = make([]string, 0, len(xfile.SqlJoins))
|
|
||||||
|
|
||||||
for _, joinClause := range xfile.SqlJoins {
|
|
||||||
// Sanitize the join clause
|
|
||||||
sanitizedJoin := common.SanitizeWhereClause(joinClause, "", nil)
|
|
||||||
if sanitizedJoin == "" {
|
|
||||||
logger.Warn("X-Files: SqlJoin failed sanitization for %s: %s", relationPath, joinClause)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
preloadOpt.SqlJoins = append(preloadOpt.SqlJoins, sanitizedJoin)
|
|
||||||
|
|
||||||
// Extract join alias for validation
|
|
||||||
alias := extractJoinAlias(sanitizedJoin)
|
|
||||||
if alias != "" {
|
|
||||||
preloadOpt.JoinAliases = append(preloadOpt.JoinAliases, alias)
|
|
||||||
logger.Debug("X-Files: Extracted join alias for %s: %s", relationPath, alias)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug("X-Files: Added %d SQL joins to preload %s", len(preloadOpt.SqlJoins), relationPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
// Check if this table has a recursive child - if so, mark THIS preload as recursive
|
||||||
// and store the recursive child's RelatedKey for recursion generation
|
// and store the recursive child's RelatedKey for recursion generation
|
||||||
hasRecursiveChild := false
|
hasRecursiveChild := false
|
||||||
|
|||||||
@@ -142,6 +142,16 @@ func TestExtractJoinAlias(t *testing.T) {
|
|||||||
joinClause: "LEFT JOIN departments",
|
joinClause: "LEFT JOIN departments",
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "LATERAL join with alias",
|
||||||
|
joinClause: "inner join lateral (select sortorder from compute_fn(t.id)) fn on true",
|
||||||
|
expected: "fn",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LATERAL join with multiline subquery containing inner ON",
|
||||||
|
joinClause: "inner join lateral (\nselect string_agg(a.name, '.') as sortorder\nfrom tree(t.id) r\ninner join account a on a.id = r.id\n) fn on true",
|
||||||
|
expected: "fn",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
@@ -12,6 +12,10 @@ import (
|
|||||||
type HookType string
|
type HookType string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
// BeforeHandle fires after model resolution, before operation dispatch.
|
||||||
|
// Use this for auth checks that need model rules and user context simultaneously.
|
||||||
|
BeforeHandle HookType = "before_handle"
|
||||||
|
|
||||||
// Read operation hooks
|
// Read operation hooks
|
||||||
BeforeRead HookType = "before_read"
|
BeforeRead HookType = "before_read"
|
||||||
AfterRead HookType = "after_read"
|
AfterRead HookType = "after_read"
|
||||||
@@ -42,6 +46,9 @@ type HookContext struct {
|
|||||||
Model interface{}
|
Model interface{}
|
||||||
Options ExtendedRequestOptions
|
Options ExtendedRequestOptions
|
||||||
|
|
||||||
|
// Operation being dispatched (e.g. "read", "create", "update", "delete")
|
||||||
|
Operation string
|
||||||
|
|
||||||
// Operation-specific fields
|
// Operation-specific fields
|
||||||
ID string
|
ID string
|
||||||
Data interface{} // For create/update operations
|
Data interface{} // For create/update operations
|
||||||
@@ -56,6 +63,14 @@ type HookContext struct {
|
|||||||
// Response writer - allows hooks to modify response
|
// Response writer - allows hooks to modify response
|
||||||
Writer common.ResponseWriter
|
Writer common.ResponseWriter
|
||||||
|
|
||||||
|
// Request - the original HTTP request
|
||||||
|
Request common.Request
|
||||||
|
|
||||||
|
// Allow hooks to abort the operation
|
||||||
|
Abort bool // If set to true, the operation will be aborted
|
||||||
|
AbortMessage string // Message to return if aborted
|
||||||
|
AbortCode int // HTTP status code if aborted
|
||||||
|
|
||||||
// Tx provides access to the database/transaction for executing additional SQL
|
// Tx provides access to the database/transaction for executing additional SQL
|
||||||
// This allows hooks to run custom queries in addition to the main Query chain
|
// This allows hooks to run custom queries in addition to the main Query chain
|
||||||
Tx common.Database
|
Tx common.Database
|
||||||
@@ -110,6 +125,12 @@ func (r *HookRegistry) Execute(hookType HookType, ctx *HookContext) error {
|
|||||||
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
logger.Error("Hook %d for %s failed: %v", i+1, hookType, err)
|
||||||
return fmt.Errorf("hook execution failed: %w", err)
|
return fmt.Errorf("hook execution failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if hook requested abort
|
||||||
|
if ctx.Abort {
|
||||||
|
logger.Warn("Hook %d for %s requested abort: %s", i+1, hookType, ctx.AbortMessage)
|
||||||
|
return fmt.Errorf("operation aborted by hook: %s", ctx.AbortMessage)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.Debug("All hooks for %s executed successfully", hookType)
|
// logger.Debug("All hooks for %s executed successfully", hookType)
|
||||||
|
|||||||
@@ -125,17 +125,17 @@ func SetupMuxRoutes(muxRouter *mux.Router, handler *Handler, authMiddleware Midd
|
|||||||
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
metadataPath := buildRoutePath(schema, entity) + "/metadata"
|
||||||
|
|
||||||
// Create handler functions for this specific entity
|
// Create handler functions for this specific entity
|
||||||
entityHandler := createMuxHandler(handler, schema, entity, "")
|
var entityHandler http.Handler = createMuxHandler(handler, schema, entity, "")
|
||||||
entityWithIDHandler := createMuxHandler(handler, schema, entity, "id")
|
var entityWithIDHandler http.Handler = createMuxHandler(handler, schema, entity, "id")
|
||||||
metadataHandler := createMuxGetHandler(handler, schema, entity, "")
|
var metadataHandler http.Handler = createMuxGetHandler(handler, schema, entity, "")
|
||||||
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
optionsEntityHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "POST", "OPTIONS"})
|
||||||
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
optionsEntityWithIDHandler := createMuxOptionsHandler(handler, schema, entity, []string{"GET", "PUT", "PATCH", "DELETE", "POST", "OPTIONS"})
|
||||||
|
|
||||||
// Apply authentication middleware if provided
|
// Apply authentication middleware if provided
|
||||||
if authMiddleware != nil {
|
if authMiddleware != nil {
|
||||||
entityHandler = authMiddleware(entityHandler).(http.HandlerFunc)
|
entityHandler = authMiddleware(entityHandler)
|
||||||
entityWithIDHandler = authMiddleware(entityWithIDHandler).(http.HandlerFunc)
|
entityWithIDHandler = authMiddleware(entityWithIDHandler)
|
||||||
metadataHandler = authMiddleware(metadataHandler).(http.HandlerFunc)
|
metadataHandler = authMiddleware(metadataHandler)
|
||||||
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,9 +280,34 @@ type BunRouterHandler interface {
|
|||||||
Handle(method, path string, handler bunrouter.HandlerFunc)
|
Handle(method, path string, handler bunrouter.HandlerFunc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapBunRouterHandler wraps a bunrouter handler with auth middleware if provided
|
||||||
|
func wrapBunRouterHandler(handler bunrouter.HandlerFunc, authMiddleware MiddlewareFunc) bunrouter.HandlerFunc {
|
||||||
|
if authMiddleware == nil {
|
||||||
|
return handler
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
// Create an http.Handler that calls the bunrouter handler
|
||||||
|
httpHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Replace the embedded *http.Request with the middleware-enriched one
|
||||||
|
// so that auth context (user ID, etc.) is visible to the handler.
|
||||||
|
enrichedReq := req
|
||||||
|
enrichedReq.Request = r
|
||||||
|
_ = handler(w, enrichedReq)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Wrap with auth middleware and execute
|
||||||
|
wrappedHandler := authMiddleware(httpHandler)
|
||||||
|
wrappedHandler.ServeHTTP(w, req.Request)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
// SetupBunRouterRoutes sets up bunrouter routes for the RestHeadSpec API
|
||||||
// Accepts bunrouter.Router or bunrouter.Group
|
// Accepts bunrouter.Router or bunrouter.Group
|
||||||
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
// authMiddleware is optional - if provided, routes will be protected with the middleware
|
||||||
|
func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler, authMiddleware MiddlewareFunc) {
|
||||||
|
|
||||||
// CORS config
|
// CORS config
|
||||||
corsConfig := common.DefaultCORSConfig()
|
corsConfig := common.DefaultCORSConfig()
|
||||||
@@ -292,6 +317,14 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
|
handler.HandleOpenAPI(respAdapter, reqAdapter)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
r.Handle("OPTIONS", "/openapi", func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
|
reqAdapter := router.NewHTTPRequest(req.Request)
|
||||||
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -313,7 +346,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
currentEntity := entity
|
currentEntity := entity
|
||||||
|
|
||||||
// GET and POST for /{schema}/{entity}
|
// GET and POST for /{schema}/{entity}
|
||||||
r.Handle("GET", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
getEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -324,9 +357,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("GET", entityPath, wrapBunRouterHandler(getEntityHandler, authMiddleware))
|
||||||
|
|
||||||
r.Handle("POST", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
postEntityHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -337,10 +371,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("POST", entityPath, wrapBunRouterHandler(postEntityHandler, authMiddleware))
|
||||||
|
|
||||||
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
// GET, POST, PUT, PATCH, DELETE for /{schema}/{entity}/:id
|
||||||
r.Handle("GET", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
getEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -352,9 +387,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("GET", entityWithIDPath, wrapBunRouterHandler(getEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
r.Handle("POST", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
postEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -366,9 +402,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("POST", entityWithIDPath, wrapBunRouterHandler(postEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
r.Handle("PUT", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
putEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -380,9 +417,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("PUT", entityWithIDPath, wrapBunRouterHandler(putEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
r.Handle("PATCH", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
patchEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -394,9 +432,10 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("PATCH", entityWithIDPath, wrapBunRouterHandler(patchEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
r.Handle("DELETE", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
deleteEntityWithIDHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -408,10 +447,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.Handle(respAdapter, reqAdapter, params)
|
handler.Handle(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("DELETE", entityWithIDPath, wrapBunRouterHandler(deleteEntityWithIDHandler, authMiddleware))
|
||||||
|
|
||||||
// Metadata endpoint
|
// Metadata endpoint
|
||||||
r.Handle("GET", metadataPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
metadataHandler := func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
common.SetCORSHeaders(respAdapter, reqAdapter, corsConfig)
|
||||||
@@ -422,9 +462,11 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
|
|
||||||
handler.HandleGet(respAdapter, reqAdapter, params)
|
handler.HandleGet(respAdapter, reqAdapter, params)
|
||||||
return nil
|
return nil
|
||||||
})
|
}
|
||||||
|
r.Handle("GET", metadataPath, wrapBunRouterHandler(metadataHandler, authMiddleware))
|
||||||
|
|
||||||
// OPTIONS route without ID (returns metadata)
|
// OPTIONS route without ID (returns metadata)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
@@ -441,6 +483,7 @@ func SetupBunRouterRoutes(r BunRouterHandler, handler *Handler) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// OPTIONS route with ID (returns metadata)
|
// OPTIONS route with ID (returns metadata)
|
||||||
|
// Don't apply auth middleware to OPTIONS - CORS preflight must not require auth
|
||||||
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
r.Handle("OPTIONS", entityWithIDPath, func(w http.ResponseWriter, req bunrouter.Request) error {
|
||||||
respAdapter := router.NewHTTPResponseWriter(w)
|
respAdapter := router.NewHTTPResponseWriter(w)
|
||||||
reqAdapter := router.NewBunRouterRequest(req)
|
reqAdapter := router.NewBunRouterRequest(req)
|
||||||
@@ -466,8 +509,8 @@ func ExampleBunRouterWithBunDB(bunDB *bun.DB) {
|
|||||||
// Create bunrouter
|
// Create bunrouter
|
||||||
bunRouter := bunrouter.New()
|
bunRouter := bunrouter.New()
|
||||||
|
|
||||||
// Setup routes
|
// Setup routes without authentication
|
||||||
SetupBunRouterRoutes(bunRouter, handler)
|
SetupBunRouterRoutes(bunRouter, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||||
@@ -487,7 +530,7 @@ func ExampleBunRouterWithGroup(bunDB *bun.DB) {
|
|||||||
apiGroup := bunRouter.NewGroup("/api")
|
apiGroup := bunRouter.NewGroup("/api")
|
||||||
|
|
||||||
// Setup RestHeadSpec routes on the group - routes will be under /api
|
// Setup RestHeadSpec routes on the group - routes will be under /api
|
||||||
SetupBunRouterRoutes(apiGroup, handler)
|
SetupBunRouterRoutes(apiGroup, handler, nil)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
if err := http.ListenAndServe(":8080", bunRouter); err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package restheadspec
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/security"
|
"github.com/bitechdev/ResolveSpec/pkg/security"
|
||||||
@@ -9,6 +10,17 @@ import (
|
|||||||
|
|
||||||
// RegisterSecurityHooks registers all security-related hooks with the handler
|
// RegisterSecurityHooks registers all security-related hooks with the handler
|
||||||
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList) {
|
||||||
|
// Hook 0: BeforeHandle - enforce auth after model resolution
|
||||||
|
handler.Hooks().Register(BeforeHandle, func(hookCtx *HookContext) error {
|
||||||
|
if err := security.CheckModelAuthAllowed(newSecurityContext(hookCtx), hookCtx.Operation); err != nil {
|
||||||
|
hookCtx.Abort = true
|
||||||
|
hookCtx.AbortMessage = err.Error()
|
||||||
|
hookCtx.AbortCode = http.StatusUnauthorized
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
// Hook 1: BeforeRead - Load security rules
|
// Hook 1: BeforeRead - Load security rules
|
||||||
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
handler.Hooks().Register(BeforeRead, func(hookCtx *HookContext) error {
|
||||||
secCtx := newSecurityContext(hookCtx)
|
secCtx := newSecurityContext(hookCtx)
|
||||||
@@ -33,6 +45,18 @@ func RegisterSecurityHooks(handler *Handler, securityList *security.SecurityList
|
|||||||
return security.LogDataAccess(secCtx)
|
return security.LogDataAccess(secCtx)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Hook 5: BeforeUpdate - enforce CanUpdate rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeUpdate, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelUpdateAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Hook 6: BeforeDelete - enforce CanDelete rule from context/registry
|
||||||
|
handler.Hooks().Register(BeforeDelete, func(hookCtx *HookContext) error {
|
||||||
|
secCtx := newSecurityContext(hookCtx)
|
||||||
|
return security.CheckModelDeleteAllowed(secCtx)
|
||||||
|
})
|
||||||
|
|
||||||
logger.Info("Security hooks registered for restheadspec handler")
|
logger.Info("Security hooks registered for restheadspec handler")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# Keystore
|
||||||
|
|
||||||
|
Per-user named auth keys with pluggable storage. Each user can hold multiple keys of different types — JWT secrets, header API keys, OAuth2 client credentials, or generic API keys. Keys are identified by a human-readable name ("CI deploy", "mobile app") and can carry scopes and arbitrary metadata.
|
||||||
|
|
||||||
|
## Key types
|
||||||
|
|
||||||
|
| Constant | Value | Use case |
|
||||||
|
|---|---|---|
|
||||||
|
| `KeyTypeJWTSecret` | `jwt_secret` | Per-user JWT signing secret |
|
||||||
|
| `KeyTypeHeaderAPI` | `header_api` | Static API key sent in a request header |
|
||||||
|
| `KeyTypeOAuth2` | `oauth2` | OAuth2 client credentials |
|
||||||
|
| `KeyTypeGenericAPI` | `api` | General-purpose application key |
|
||||||
|
|
||||||
|
## Storage backends
|
||||||
|
|
||||||
|
### ConfigKeyStore
|
||||||
|
|
||||||
|
In-memory store seeded from a static list. Suitable for a small, fixed set of service-account keys loaded from a config file. Keys created at runtime via `CreateKey` are held in memory and lost on restart.
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Pre-load keys from config (KeyHash = SHA-256 hex of the raw key)
|
||||||
|
store := security.NewConfigKeyStore([]security.UserKey{
|
||||||
|
{
|
||||||
|
UserID: 1,
|
||||||
|
KeyType: security.KeyTypeGenericAPI,
|
||||||
|
KeyHash: "e3b0c44298fc1c149afb...", // sha256(rawKey)
|
||||||
|
Name: "CI deploy",
|
||||||
|
Scopes: []string{"deploy"},
|
||||||
|
IsActive: true,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### DatabaseKeyStore
|
||||||
|
|
||||||
|
Backed by PostgreSQL stored procedures. Supports optional caching (default 2-minute TTL). Apply `keystore_schema.sql` before use.
|
||||||
|
|
||||||
|
```go
|
||||||
|
db, _ := sql.Open("postgres", dsn)
|
||||||
|
|
||||||
|
store := security.NewDatabaseKeyStore(db)
|
||||||
|
|
||||||
|
// With options
|
||||||
|
store = security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||||
|
CacheTTL: 5 * time.Minute,
|
||||||
|
SQLNames: &security.KeyStoreSQLNames{
|
||||||
|
ValidateKey: "myapp_keystore_validate", // override one procedure name
|
||||||
|
},
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## Managing keys
|
||||||
|
|
||||||
|
```go
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Create — raw key returned once; store it securely
|
||||||
|
resp, err := store.CreateKey(ctx, security.CreateKeyRequest{
|
||||||
|
UserID: 42,
|
||||||
|
KeyType: security.KeyTypeGenericAPI,
|
||||||
|
Name: "mobile app",
|
||||||
|
Scopes: []string{"read", "write"},
|
||||||
|
})
|
||||||
|
fmt.Println(resp.RawKey) // only shown here; hashed internally
|
||||||
|
|
||||||
|
// List
|
||||||
|
keys, err := store.GetUserKeys(ctx, 42, "") // "" = all types
|
||||||
|
keys, err = store.GetUserKeys(ctx, 42, security.KeyTypeGenericAPI)
|
||||||
|
|
||||||
|
// Revoke
|
||||||
|
err = store.DeleteKey(ctx, 42, resp.Key.ID)
|
||||||
|
|
||||||
|
// Validate (used by authenticators internally)
|
||||||
|
key, err := store.ValidateKey(ctx, rawKey, "")
|
||||||
|
```
|
||||||
|
|
||||||
|
## HTTP authentication
|
||||||
|
|
||||||
|
`KeyStoreAuthenticator` wraps any `KeyStore` and implements the `Authenticator` interface. It is drop-in compatible with `DatabaseAuthenticator` and works in `CompositeSecurityProvider`.
|
||||||
|
|
||||||
|
Keys are extracted from the request in this order:
|
||||||
|
|
||||||
|
1. `Authorization: Bearer <key>`
|
||||||
|
2. `Authorization: ApiKey <key>`
|
||||||
|
3. `X-API-Key: <key>`
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewKeyStoreAuthenticator(store, "") // "" = accept any key type
|
||||||
|
// Restrict to a specific type:
|
||||||
|
auth = security.NewKeyStoreAuthenticator(store, security.KeyTypeGenericAPI)
|
||||||
|
```
|
||||||
|
|
||||||
|
Plug it into a handler:
|
||||||
|
|
||||||
|
```go
|
||||||
|
handler := resolvespec.NewHandler(db, registry,
|
||||||
|
resolvespec.WithAuthenticator(auth),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
`Login` and `Logout` return an error — key lifecycle is managed through `KeyStore` directly.
|
||||||
|
|
||||||
|
On successful validation the request context receives a `UserContext` where:
|
||||||
|
|
||||||
|
- `UserID` — from the key
|
||||||
|
- `Roles` — the key's `Scopes`
|
||||||
|
- `Claims["key_type"]` — key type string
|
||||||
|
- `Claims["key_name"]` — key name
|
||||||
|
|
||||||
|
## Database setup
|
||||||
|
|
||||||
|
Apply `keystore_schema.sql` to your PostgreSQL database. It requires the `users` table from the main `database_schema.sql`.
|
||||||
|
|
||||||
|
```sql
|
||||||
|
\i pkg/security/keystore_schema.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates:
|
||||||
|
|
||||||
|
- `user_keys` table with indexes on `user_id`, `key_hash`, and `key_type`
|
||||||
|
- `resolvespec_keystore_get_user_keys(p_user_id, p_key_type)`
|
||||||
|
- `resolvespec_keystore_create_key(p_request jsonb)`
|
||||||
|
- `resolvespec_keystore_delete_key(p_user_id, p_key_id)`
|
||||||
|
- `resolvespec_keystore_validate_key(p_key_hash, p_key_type)`
|
||||||
|
|
||||||
|
### Custom procedure names
|
||||||
|
|
||||||
|
```go
|
||||||
|
store := security.NewDatabaseKeyStore(db, security.DatabaseKeyStoreOptions{
|
||||||
|
SQLNames: &security.KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "myschema_get_keys",
|
||||||
|
CreateKey: "myschema_create_key",
|
||||||
|
DeleteKey: "myschema_delete_key",
|
||||||
|
ValidateKey: "myschema_validate_key",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Validate names at startup
|
||||||
|
names := &security.KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "myschema_get_keys",
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
if err := security.ValidateKeyStoreSQLNames(names); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security notes
|
||||||
|
|
||||||
|
- Raw keys are never stored. Only the SHA-256 hex digest is persisted.
|
||||||
|
- The raw key is generated with `crypto/rand` (32 bytes, base64url-encoded) and returned exactly once in `CreateKeyResponse.RawKey`.
|
||||||
|
- Hash comparisons in `ConfigKeyStore` use `crypto/subtle.ConstantTimeCompare` to prevent timing side-channels.
|
||||||
|
- `DeleteKey` performs a soft delete (`is_active = false`). The `DatabaseKeyStore` invalidates the cache entry immediately, but due to the cache TTL a revoked key may authenticate for up to `CacheTTL` (default 2 minutes) in a distributed environment. Set `CacheTTL: 0` to disable caching if immediate revocation is required.
|
||||||
@@ -258,11 +258,8 @@ func (a *JWTAuthenticator) Login(ctx context.Context, req security.LoginRequest)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
func (a *JWTAuthenticator) Logout(ctx context.Context, req security.LogoutRequest) error {
|
||||||
// Add to blacklist
|
// Invalidate session via stored procedure
|
||||||
return a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]any{
|
return nil
|
||||||
"token": req.Token,
|
|
||||||
"user_id": req.UserID,
|
|
||||||
}).Error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
func (a *JWTAuthenticator) Authenticate(r *http.Request) (*security.UserContext, error) {
|
||||||
@@ -405,11 +402,16 @@ assert.Equal(t, "user_id = {UserID}", row.Template)
|
|||||||
```
|
```
|
||||||
HTTP Request
|
HTTP Request
|
||||||
↓
|
↓
|
||||||
NewAuthMiddleware → calls provider.Authenticate()
|
NewOptionalAuthMiddleware → calls provider.Authenticate()
|
||||||
↓ (adds UserContext to context)
|
↓ (adds UserContext or guest context; never 401)
|
||||||
SetSecurityMiddleware → adds SecurityList to context
|
SetSecurityMiddleware → adds SecurityList to context
|
||||||
↓
|
↓
|
||||||
Handler.Handle()
|
Handler.Handle() → resolves model
|
||||||
|
↓
|
||||||
|
BeforeHandle Hook → CheckModelAuthAllowed(secCtx, operation)
|
||||||
|
├─ SecurityDisabled → allow
|
||||||
|
├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||||
|
└─ UserID == 0 → abort 401
|
||||||
↓
|
↓
|
||||||
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
BeforeRead Hook → calls provider.GetColumnSecurity() + GetRowSecurity()
|
||||||
↓
|
↓
|
||||||
@@ -693,15 +695,30 @@ http.Handle("/api/protected", authHandler)
|
|||||||
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
optionalHandler := security.NewOptionalAuthHandler(securityList, myHandler)
|
||||||
http.Handle("/home", optionalHandler)
|
http.Handle("/home", optionalHandler)
|
||||||
|
|
||||||
// Example handler
|
// NewOptionalAuthMiddleware - For spec routes; auth enforcement deferred to BeforeHandle
|
||||||
func myHandler(w http.ResponseWriter, r *http.Request) {
|
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||||
userCtx, _ := security.GetUserContext(r.Context())
|
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||||
if userCtx.UserID == 0 {
|
restheadspec.RegisterSecurityHooks(handler, securityList) // includes BeforeHandle
|
||||||
// Guest user
|
```
|
||||||
} else {
|
|
||||||
// Authenticated user
|
---
|
||||||
}
|
|
||||||
}
|
## Model-Level Access Control
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Register model with rules (pkg/modelregistry)
|
||||||
|
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||||
|
SecurityDisabled: false, // skip all auth when true
|
||||||
|
CanPublicRead: true, // unauthenticated reads allowed
|
||||||
|
CanPublicCreate: false, // requires auth
|
||||||
|
CanPublicUpdate: false, // requires auth
|
||||||
|
CanPublicDelete: false, // requires auth
|
||||||
|
CanUpdate: true, // authenticated can update
|
||||||
|
CanDelete: false, // authenticated cannot delete (enforced in BeforeDelete)
|
||||||
|
})
|
||||||
|
|
||||||
|
// CheckModelAuthAllowed used automatically in BeforeHandle hook
|
||||||
|
// No code needed — call RegisterSecurityHooks and it's applied
|
||||||
```
|
```
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
+279
-5
@@ -12,6 +12,8 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
|||||||
- ✅ **Testable** - Easy to mock and test
|
- ✅ **Testable** - Easy to mock and test
|
||||||
- ✅ **Extensible** - Implement custom providers for your needs
|
- ✅ **Extensible** - Implement custom providers for your needs
|
||||||
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
- ✅ **Stored Procedures** - All database operations use PostgreSQL stored procedures for security and maintainability
|
||||||
|
- ✅ **OAuth2 Authorization Server** - Built-in OAuth 2.1 + PKCE server (RFC 8414, 7591, 7009, 7662) with login form and external provider federation
|
||||||
|
- ✅ **Password Reset** - Self-service password reset with secure token generation and session invalidation
|
||||||
|
|
||||||
## Stored Procedure Architecture
|
## Stored Procedure Architecture
|
||||||
|
|
||||||
@@ -38,6 +40,14 @@ Type-safe, composable security system for ResolveSpec with support for authentic
|
|||||||
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
| `resolvespec_jwt_logout` | JWT token blacklist | JWTAuthenticator |
|
||||||
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
| `resolvespec_column_security` | Load column rules | DatabaseColumnSecurityProvider |
|
||||||
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
| `resolvespec_row_security` | Load row templates | DatabaseRowSecurityProvider |
|
||||||
|
| `resolvespec_oauth_register_client` | Persist OAuth2 client (RFC 7591) | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_oauth_get_client` | Retrieve OAuth2 client by ID | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_oauth_save_code` | Persist authorization code | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_oauth_exchange_code` | Consume authorization code (single-use) | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_oauth_introspect` | Token introspection (RFC 7662) | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_oauth_revoke` | Token revocation (RFC 7009) | OAuthServer / DatabaseAuthenticator |
|
||||||
|
| `resolvespec_password_reset_request` | Create password reset token | DatabaseAuthenticator |
|
||||||
|
| `resolvespec_password_reset` | Validate token and set new password | DatabaseAuthenticator |
|
||||||
|
|
||||||
See `database_schema.sql` for complete stored procedure definitions and examples.
|
See `database_schema.sql` for complete stored procedure definitions and examples.
|
||||||
|
|
||||||
@@ -751,14 +761,25 @@ resolvespec.RegisterSecurityHooks(resolveHandler, securityList)
|
|||||||
```
|
```
|
||||||
HTTP Request
|
HTTP Request
|
||||||
↓
|
↓
|
||||||
NewAuthMiddleware (security package)
|
NewOptionalAuthMiddleware (security package) ← recommended for spec routes
|
||||||
├─ Calls provider.Authenticate(request)
|
├─ Calls provider.Authenticate(request)
|
||||||
└─ Adds UserContext to context
|
├─ On success: adds authenticated UserContext to context
|
||||||
|
└─ On failure: adds guest UserContext (UserID=0) to context
|
||||||
↓
|
↓
|
||||||
SetSecurityMiddleware (security package)
|
SetSecurityMiddleware (security package)
|
||||||
└─ Adds SecurityList to context
|
└─ Adds SecurityList to context
|
||||||
↓
|
↓
|
||||||
Spec Handler (restheadspec/funcspec/resolvespec)
|
Spec Handler (restheadspec/funcspec/resolvespec/websocketspec/mqttspec)
|
||||||
|
└─ Resolves schema + entity + model from request
|
||||||
|
↓
|
||||||
|
BeforeHandle Hook (registered by spec via RegisterSecurityHooks)
|
||||||
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
|
├─ Calls security.CheckModelAuthAllowed(secCtx, operation)
|
||||||
|
│ ├─ Loads model rules from context or registry
|
||||||
|
│ ├─ SecurityDisabled → allow
|
||||||
|
│ ├─ CanPublicRead/Create/Update/Delete → allow unauthenticated
|
||||||
|
│ └─ UserID == 0 → 401 unauthorized
|
||||||
|
└─ On error: aborts with 401
|
||||||
↓
|
↓
|
||||||
BeforeRead Hook (registered by spec)
|
BeforeRead Hook (registered by spec)
|
||||||
├─ Adapts spec's HookContext → SecurityContext
|
├─ Adapts spec's HookContext → SecurityContext
|
||||||
@@ -784,7 +805,8 @@ HTTP Response (secured data)
|
|||||||
```
|
```
|
||||||
|
|
||||||
**Key Points:**
|
**Key Points:**
|
||||||
- Security package is spec-agnostic and provides core logic
|
- `NewOptionalAuthMiddleware` never rejects — it sets guest context on auth failure; `BeforeHandle` enforces auth after model resolution
|
||||||
|
- `BeforeHandle` fires after model resolution, giving access to model rules and user context simultaneously
|
||||||
- Each spec registers its own hooks that adapt to SecurityContext
|
- Each spec registers its own hooks that adapt to SecurityContext
|
||||||
- Security rules are loaded once and cached for the request
|
- Security rules are loaded once and cached for the request
|
||||||
- Row security is applied to the query (database level)
|
- Row security is applied to the query (database level)
|
||||||
@@ -885,6 +907,216 @@ securityList := security.NewSecurityList(provider)
|
|||||||
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
restheadspec.RegisterSecurityHooks(handler, securityList) // or funcspec/resolvespec
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Password Reset
|
||||||
|
|
||||||
|
`DatabaseAuthenticator` implements `PasswordResettable` for self-service password reset.
|
||||||
|
|
||||||
|
### Flow
|
||||||
|
|
||||||
|
1. User submits email or username → `RequestPasswordReset` → server generates a token and returns it for out-of-band delivery (email, SMS, etc.)
|
||||||
|
2. User submits the raw token + new password → `CompletePasswordReset` → password updated, all sessions invalidated
|
||||||
|
|
||||||
|
### DB Requirements
|
||||||
|
|
||||||
|
Run the migrations in `database_schema.sql`:
|
||||||
|
- `user_password_resets` table (`user_id`, `token_hash` SHA-256, `expires_at`, `used`, `used_at`)
|
||||||
|
- `resolvespec_password_reset_request` stored procedure
|
||||||
|
- `resolvespec_password_reset` stored procedure
|
||||||
|
|
||||||
|
Requires the `pgcrypto` extension (`gen_random_bytes`, `digest`) — already used by `resolvespec_login`.
|
||||||
|
|
||||||
|
### Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewDatabaseAuthenticator(db)
|
||||||
|
|
||||||
|
// Step 1 — initiate reset (call after user submits their email)
|
||||||
|
resp, err := auth.RequestPasswordReset(ctx, security.PasswordResetRequest{
|
||||||
|
Email: "user@example.com",
|
||||||
|
})
|
||||||
|
// resp.Token is the raw token — deliver it out-of-band
|
||||||
|
// resp.ExpiresIn is 3600 (1 hour)
|
||||||
|
// Always returns success regardless of whether the user exists (anti-enumeration)
|
||||||
|
|
||||||
|
// Step 2 — complete reset (call after user submits token + new password)
|
||||||
|
err = auth.CompletePasswordReset(ctx, security.PasswordResetCompleteRequest{
|
||||||
|
Token: rawToken,
|
||||||
|
NewPassword: "newSecurePassword",
|
||||||
|
})
|
||||||
|
// On success: password updated, all active sessions deleted
|
||||||
|
```
|
||||||
|
|
||||||
|
### Security Notes
|
||||||
|
|
||||||
|
- The raw token is never stored; only its SHA-256 hash is persisted
|
||||||
|
- Requesting a reset invalidates any previous unused tokens for that user
|
||||||
|
- Tokens expire after 1 hour
|
||||||
|
- Completing a reset deletes all active sessions, forcing re-login
|
||||||
|
- `RequestPasswordReset` always returns success even when the email/username is not found, preventing user enumeration
|
||||||
|
- Hash the new password with bcrypt before storing (pgcrypto `crypt`/`gen_salt`) — see the TODO comment in `resolvespec_password_reset`
|
||||||
|
|
||||||
|
### SQLNames
|
||||||
|
|
||||||
|
```go
|
||||||
|
type SQLNames struct {
|
||||||
|
// ...
|
||||||
|
PasswordResetRequest string // default: "resolvespec_password_reset_request"
|
||||||
|
PasswordResetComplete string // default: "resolvespec_password_reset"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## OAuth2 Authorization Server
|
||||||
|
|
||||||
|
`OAuthServer` is a generic OAuth 2.1 + PKCE authorization server. It is not tied to any spec — `pkg/resolvemcp` uses it, but it can be used standalone with any `http.ServeMux`.
|
||||||
|
|
||||||
|
### Endpoints
|
||||||
|
|
||||||
|
| Method | Path | RFC |
|
||||||
|
|--------|------|-----|
|
||||||
|
| `GET` | `/.well-known/oauth-authorization-server` | RFC 8414 — server metadata |
|
||||||
|
| `POST` | `/oauth/register` | RFC 7591 — dynamic client registration |
|
||||||
|
| `GET` | `/oauth/authorize` | OAuth 2.1 — start authorization / provider selection |
|
||||||
|
| `POST` | `/oauth/authorize` | OAuth 2.1 — login form submission |
|
||||||
|
| `POST` | `/oauth/token` | OAuth 2.1 — code exchange + refresh |
|
||||||
|
| `POST` | `/oauth/revoke` | RFC 7009 — token revocation |
|
||||||
|
| `POST` | `/oauth/introspect` | RFC 7662 — token introspection |
|
||||||
|
| `GET` | `{ProviderCallbackPath}` | External provider redirect target |
|
||||||
|
|
||||||
|
### Config
|
||||||
|
|
||||||
|
```go
|
||||||
|
cfg := security.OAuthServerConfig{
|
||||||
|
Issuer: "https://example.com", // Required — token issuer URL
|
||||||
|
ProviderCallbackPath: "/oauth/provider/callback", // External provider redirect target
|
||||||
|
LoginTitle: "My App Login", // HTML login page title
|
||||||
|
PersistClients: true, // Store clients in DB (multi-instance safe)
|
||||||
|
PersistCodes: true, // Store codes in DB (multi-instance safe)
|
||||||
|
DefaultScopes: []string{"openid", "profile"}, // Returned when no scope requested
|
||||||
|
AccessTokenTTL: time.Hour,
|
||||||
|
AuthCodeTTL: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Field | Default | Notes |
|
||||||
|
|-------|---------|-------|
|
||||||
|
| `Issuer` | — | Required; trailing slash is trimmed automatically |
|
||||||
|
| `ProviderCallbackPath` | `/oauth/provider/callback` | |
|
||||||
|
| `LoginTitle` | `"Sign in"` | |
|
||||||
|
| `PersistClients` | `false` | Set `true` for multi-instance |
|
||||||
|
| `PersistCodes` | `false` | Set `true` for multi-instance; does not require `PersistClients` |
|
||||||
|
| `DefaultScopes` | `["openid","profile","email"]` | |
|
||||||
|
| `AccessTokenTTL` | `24h` | Also used as `expires_in` in token responses |
|
||||||
|
| `AuthCodeTTL` | `2m` | |
|
||||||
|
|
||||||
|
### Operating Modes
|
||||||
|
|
||||||
|
**Mode 1 — Direct login (username/password form)**
|
||||||
|
|
||||||
|
Pass a `*DatabaseAuthenticator` to `NewOAuthServer`. The server renders a login form at `GET /oauth/authorize` and issues tokens via the stored session after login.
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewDatabaseAuthenticator(db)
|
||||||
|
srv := security.NewOAuthServer(cfg, auth)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Mode 2 — External provider federation**
|
||||||
|
|
||||||
|
Pass a `*DatabaseAuthenticator` for persistence (authorization codes, revoke, introspect) and register external providers. The authorize endpoint redirects to the specified provider (via the `provider` query param) or to the first registered provider by default.
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth := security.NewDatabaseAuthenticator(db)
|
||||||
|
srv := security.NewOAuthServer(cfg, auth)
|
||||||
|
srv.RegisterExternalProvider(googleAuth, "google")
|
||||||
|
srv.RegisterExternalProvider(githubAuth, "github")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Mode 3 — Both**
|
||||||
|
|
||||||
|
Pass auth for the login form and also register external providers. The authorize page shows both a login form and provider buttons.
|
||||||
|
|
||||||
|
```go
|
||||||
|
srv := security.NewOAuthServer(cfg, auth)
|
||||||
|
srv.RegisterExternalProvider(googleAuth, "google")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Standalone Usage
|
||||||
|
|
||||||
|
```go
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle("/.well-known/", srv.HTTPHandler())
|
||||||
|
mux.Handle("/oauth/", srv.HTTPHandler())
|
||||||
|
mux.Handle(cfg.ProviderCallbackPath, srv.HTTPHandler())
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", mux)
|
||||||
|
```
|
||||||
|
|
||||||
|
### DB Persistence
|
||||||
|
|
||||||
|
When `PersistClients: true` or `PersistCodes: true`, the server calls the corresponding `DatabaseAuthenticator` methods. Both flags default to `false` (in-memory maps). Enable both for multi-instance deployments.
|
||||||
|
|
||||||
|
Requires `oauth_clients` and `oauth_codes` tables + 6 stored procedures from `database_schema.sql`.
|
||||||
|
|
||||||
|
#### New DB Types
|
||||||
|
|
||||||
|
```go
|
||||||
|
type OAuthServerClient struct {
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
RedirectURIs []string `json:"redirect_uris"`
|
||||||
|
ClientName string `json:"client_name,omitempty"`
|
||||||
|
GrantTypes []string `json:"grant_types"`
|
||||||
|
AllowedScopes []string `json:"allowed_scopes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthCode struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
ClientID string `json:"client_id"`
|
||||||
|
RedirectURI string `json:"redirect_uri"`
|
||||||
|
ClientState string `json:"client_state,omitempty"`
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||||
|
SessionToken string `json:"session_token"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthTokenInfo struct {
|
||||||
|
Active bool `json:"active"`
|
||||||
|
Sub string `json:"sub,omitempty"`
|
||||||
|
Username string `json:"username,omitempty"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
Exp int64 `json:"exp,omitempty"`
|
||||||
|
Iat int64 `json:"iat,omitempty"`
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### DatabaseAuthenticator OAuth Methods
|
||||||
|
|
||||||
|
```go
|
||||||
|
auth.OAuthRegisterClient(ctx, client) // RFC 7591 — persist client
|
||||||
|
auth.OAuthGetClient(ctx, clientID) // retrieve client
|
||||||
|
auth.OAuthSaveCode(ctx, code) // persist authorization code
|
||||||
|
auth.OAuthExchangeCode(ctx, code) // consume code (single-use, deletes on read)
|
||||||
|
auth.OAuthIntrospectToken(ctx, token) // RFC 7662 — returns OAuthTokenInfo
|
||||||
|
auth.OAuthRevokeToken(ctx, token) // RFC 7009 — revoke session
|
||||||
|
```
|
||||||
|
|
||||||
|
#### SQLNames Fields
|
||||||
|
|
||||||
|
```go
|
||||||
|
type SQLNames struct {
|
||||||
|
// ... existing fields ...
|
||||||
|
OAuthRegisterClient string // default: "resolvespec_oauth_register_client"
|
||||||
|
OAuthGetClient string // default: "resolvespec_oauth_get_client"
|
||||||
|
OAuthSaveCode string // default: "resolvespec_oauth_save_code"
|
||||||
|
OAuthExchangeCode string // default: "resolvespec_oauth_exchange_code"
|
||||||
|
OAuthIntrospect string // default: "resolvespec_oauth_introspect"
|
||||||
|
OAuthRevoke string // default: "resolvespec_oauth_revoke"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
The main changes:
|
The main changes:
|
||||||
1. Security package no longer knows about specific spec types
|
1. Security package no longer knows about specific spec types
|
||||||
2. Each spec registers its own security hooks
|
2. Each spec registers its own security hooks
|
||||||
@@ -941,6 +1173,14 @@ type Cacheable interface {
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**PasswordResettable** - Self-service password reset:
|
||||||
|
```go
|
||||||
|
type PasswordResettable interface {
|
||||||
|
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
|
||||||
|
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
## Benefits Over Callbacks
|
## Benefits Over Callbacks
|
||||||
|
|
||||||
| Feature | Old (Callbacks) | New (Interfaces) |
|
| Feature | Old (Callbacks) | New (Interfaces) |
|
||||||
@@ -1002,15 +1242,49 @@ func (p *MyProvider) GetRowSecurity(ctx context.Context, userID int, schema, tab
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Model-Level Access Control
|
||||||
|
|
||||||
|
Use `ModelRules` (from `pkg/modelregistry`) to control per-entity auth behavior:
|
||||||
|
|
||||||
|
```go
|
||||||
|
modelregistry.RegisterModelWithRules("public.products", &Product{}, modelregistry.ModelRules{
|
||||||
|
SecurityDisabled: false, // true = skip all auth checks
|
||||||
|
CanPublicRead: true, // unauthenticated GET allowed
|
||||||
|
CanPublicCreate: false, // requires auth
|
||||||
|
CanPublicUpdate: false, // requires auth
|
||||||
|
CanPublicDelete: false, // requires auth
|
||||||
|
CanUpdate: true, // authenticated users can update
|
||||||
|
CanDelete: false, // authenticated users cannot delete
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
`CheckModelAuthAllowed(secCtx, operation)` applies these rules in `BeforeHandle`:
|
||||||
|
1. `SecurityDisabled` → allow all
|
||||||
|
2. `CanPublicRead/Create/Update/Delete` → allow unauthenticated for that operation
|
||||||
|
3. Guest (UserID == 0) → return 401
|
||||||
|
4. Authenticated → allow (operation-specific `CanUpdate`/`CanDelete` checked in `BeforeUpdate`/`BeforeDelete`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Middleware and Handler API
|
## Middleware and Handler API
|
||||||
|
|
||||||
### NewAuthMiddleware
|
### NewAuthMiddleware
|
||||||
Standard middleware that authenticates all requests:
|
Standard middleware that authenticates all requests and returns 401 on failure:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
router.Use(security.NewAuthMiddleware(securityList))
|
router.Use(security.NewAuthMiddleware(securityList))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### NewOptionalAuthMiddleware
|
||||||
|
Middleware for spec routes — always continues; sets guest context on auth failure:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Use with RegisterSecurityHooks — auth enforcement is deferred to BeforeHandle
|
||||||
|
apiRouter.Use(security.NewOptionalAuthMiddleware(securityList))
|
||||||
|
apiRouter.Use(security.SetSecurityMiddleware(securityList))
|
||||||
|
restheadspec.RegisterSecurityHooks(handler, securityList) // registers BeforeHandle
|
||||||
|
```
|
||||||
|
|
||||||
Routes can skip authentication using the `SkipAuth` helper:
|
Routes can skip authentication using the `SkipAuth` helper:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
|
|||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChainAuthenticator tries each authenticator in order, returning the first success.
|
||||||
|
// Login and Logout are delegated to the primary authenticator.
|
||||||
|
type ChainAuthenticator struct {
|
||||||
|
authenticators []Authenticator
|
||||||
|
authenticateCallback func(r *http.Request) (*UserContext, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewChainAuthenticator creates a ChainAuthenticator from the given authenticators.
|
||||||
|
// At least one authenticator is required; the first is treated as primary for Login/Logout.
|
||||||
|
func NewChainAuthenticator(primary Authenticator, rest ...Authenticator) *ChainAuthenticator {
|
||||||
|
return &ChainAuthenticator{
|
||||||
|
authenticators: append([]Authenticator{primary}, rest...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||||
|
var lastErr error
|
||||||
|
for _, a := range c.authenticators {
|
||||||
|
if uc, err := a.Authenticate(r); err == nil {
|
||||||
|
return uc, nil
|
||||||
|
} else {
|
||||||
|
lastErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.authenticateCallback != nil {
|
||||||
|
return c.authenticateCallback(r)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("all authenticators failed; last error: %w", lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
|
||||||
|
c.authenticateCallback = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) Login(ctx context.Context, req LoginRequest) (*LoginResponse, error) {
|
||||||
|
return c.authenticators[0].Login(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error) {
|
||||||
|
return c.authenticators[0].LoginWithCookie(ctx, req, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
|
return c.authenticators[0].Logout(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChainAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error {
|
||||||
|
return c.authenticators[0].LogoutWithCookie(ctx, req, w)
|
||||||
|
}
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// stubAuthenticator is a configurable Authenticator for testing.
|
||||||
|
type stubAuthenticator struct {
|
||||||
|
userCtx *UserContext
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) Authenticate(_ *http.Request) (*UserContext, error) {
|
||||||
|
return s.userCtx, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return nil, s.err
|
||||||
|
}
|
||||||
|
return &LoginResponse{Token: "tok"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) LoginWithCookie(ctx context.Context, req LoginRequest, _ http.ResponseWriter) (*LoginResponse, error) {
|
||||||
|
return s.Login(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) LogoutWithCookie(ctx context.Context, req LogoutRequest, _ http.ResponseWriter) error {
|
||||||
|
return s.Logout(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubAuthenticator) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {}
|
||||||
|
|
||||||
|
func TestChainAuthenticator_Authenticate(t *testing.T) {
|
||||||
|
successCtx := &UserContext{UserID: 42, UserName: "alice"}
|
||||||
|
failStub := &stubAuthenticator{err: fmt.Errorf("no token")}
|
||||||
|
okStub := &stubAuthenticator{userCtx: successCtx}
|
||||||
|
|
||||||
|
t.Run("primary succeeds", func(t *testing.T) {
|
||||||
|
chain := NewChainAuthenticator(okStub, failStub)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
uc, err := chain.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if uc.UserID != 42 {
|
||||||
|
t.Errorf("expected UserID 42, got %d", uc.UserID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("primary fails, secondary succeeds", func(t *testing.T) {
|
||||||
|
chain := NewChainAuthenticator(failStub, okStub)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
uc, err := chain.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if uc.UserID != 42 {
|
||||||
|
t.Errorf("expected UserID 42, got %d", uc.UserID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("all fail", func(t *testing.T) {
|
||||||
|
chain := NewChainAuthenticator(failStub, failStub)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
_, err := chain.Authenticate(req)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when all authenticators fail")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("three in chain, first two fail", func(t *testing.T) {
|
||||||
|
chain := NewChainAuthenticator(failStub, failStub, okStub)
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
|
||||||
|
uc, err := chain.Authenticate(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if uc.UserName != "alice" {
|
||||||
|
t.Errorf("expected UserName alice, got %s", uc.UserName)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChainAuthenticator_LoginLogout(t *testing.T) {
|
||||||
|
primary := &stubAuthenticator{userCtx: &UserContext{UserID: 1}}
|
||||||
|
secondary := &stubAuthenticator{userCtx: &UserContext{UserID: 2}}
|
||||||
|
chain := NewChainAuthenticator(primary, secondary)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("login delegates to primary", func(t *testing.T) {
|
||||||
|
resp, err := chain.Login(ctx, LoginRequest{Username: "u", Password: "p"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token != "tok" {
|
||||||
|
t.Errorf("expected token from primary, got %s", resp.Token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("logout delegates to primary", func(t *testing.T) {
|
||||||
|
if err := chain.Logout(ctx, LogoutRequest{}); err != nil {
|
||||||
|
t.Fatalf("expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("login error from primary is returned", func(t *testing.T) {
|
||||||
|
failPrimary := &stubAuthenticator{err: fmt.Errorf("db down")}
|
||||||
|
chain2 := NewChainAuthenticator(failPrimary, secondary)
|
||||||
|
_, err := chain2.Login(ctx, LoginRequest{})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error from primary login failure")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -43,16 +43,31 @@ func (c *CompositeSecurityProvider) Login(ctx context.Context, req LoginRequest)
|
|||||||
return c.auth.Login(ctx, req)
|
return c.auth.Login(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoginWithCookie delegates to the authenticator
|
||||||
|
func (c *CompositeSecurityProvider) LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error) {
|
||||||
|
return c.auth.LoginWithCookie(ctx, req, w)
|
||||||
|
}
|
||||||
|
|
||||||
// Logout delegates to the authenticator
|
// Logout delegates to the authenticator
|
||||||
func (c *CompositeSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error {
|
func (c *CompositeSecurityProvider) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
return c.auth.Logout(ctx, req)
|
return c.auth.Logout(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LogoutWithCookie delegates to the authenticator
|
||||||
|
func (c *CompositeSecurityProvider) LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error {
|
||||||
|
return c.auth.LogoutWithCookie(ctx, req, w)
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticate delegates to the authenticator
|
// Authenticate delegates to the authenticator
|
||||||
func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) {
|
func (c *CompositeSecurityProvider) Authenticate(r *http.Request) (*UserContext, error) {
|
||||||
return c.auth.Authenticate(r)
|
return c.auth.Authenticate(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAuthenticateCallback delegates to the authenticator
|
||||||
|
func (c *CompositeSecurityProvider) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
|
||||||
|
c.auth.SetAuthenticateCallback(fn)
|
||||||
|
}
|
||||||
|
|
||||||
// GetColumnSecurity delegates to the column security provider
|
// GetColumnSecurity delegates to the column security provider
|
||||||
func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
func (c *CompositeSecurityProvider) GetColumnSecurity(ctx context.Context, userID int, schema, table string) ([]ColumnSecurity, error) {
|
||||||
return c.colSec.GetColumnSecurity(ctx, userID, schema, table)
|
return c.colSec.GetColumnSecurity(ctx, userID, schema, table)
|
||||||
|
|||||||
@@ -23,14 +23,24 @@ func (m *mockAuth) Login(ctx context.Context, req LoginRequest) (*LoginResponse,
|
|||||||
return m.loginResp, m.loginErr
|
return m.loginResp, m.loginErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuth) LoginWithCookie(ctx context.Context, req LoginRequest, _ http.ResponseWriter) (*LoginResponse, error) {
|
||||||
|
return m.Login(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuth) Logout(ctx context.Context, req LogoutRequest) error {
|
func (m *mockAuth) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
return m.logoutErr
|
return m.logoutErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuth) LogoutWithCookie(ctx context.Context, req LogoutRequest, _ http.ResponseWriter) error {
|
||||||
|
return m.Logout(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockAuth) Authenticate(r *http.Request) (*UserContext, error) {
|
func (m *mockAuth) Authenticate(r *http.Request) (*UserContext, error) {
|
||||||
return m.authUser, m.authErr
|
return m.authUser, m.authErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockAuth) SetAuthenticateCallback(_ func(r *http.Request) (*UserContext, error)) {}
|
||||||
|
|
||||||
// Optional interface implementations
|
// Optional interface implementations
|
||||||
func (m *mockAuth) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
func (m *mockAuth) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||||||
if !m.supportsRefresh {
|
if !m.supportsRefresh {
|
||||||
|
|||||||
@@ -1397,3 +1397,332 @@ $$ LANGUAGE plpgsql;
|
|||||||
|
|
||||||
-- Get credentials by username
|
-- Get credentials by username
|
||||||
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
-- SELECT * FROM resolvespec_passkey_get_credentials_by_username('admin');
|
||||||
|
|
||||||
|
-- ============================================
|
||||||
|
-- Password Reset Tables
|
||||||
|
-- ============================================
|
||||||
|
|
||||||
|
-- Password reset tokens table
|
||||||
|
CREATE TABLE IF NOT EXISTS user_password_resets (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
token_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex of the raw token
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
used BOOLEAN DEFAULT false,
|
||||||
|
used_at TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_pw_reset_token_hash ON user_password_resets(token_hash);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_pw_reset_user_id ON user_password_resets(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_pw_reset_expires_at ON user_password_resets(expires_at);
|
||||||
|
|
||||||
|
-- ============================================
|
||||||
|
-- Stored Procedures for Password Reset
|
||||||
|
-- ============================================
|
||||||
|
|
||||||
|
-- 1. resolvespec_password_reset_request - Creates a password reset token for a user
|
||||||
|
-- Input: p_request jsonb {email: string, username: string}
|
||||||
|
-- Output: p_success (bool), p_error (text), p_data jsonb {token: string, expires_in: int}
|
||||||
|
-- NOTE: The raw token is returned so the caller can deliver it out-of-band (e.g. email).
|
||||||
|
-- Only the SHA-256 hash is stored. Invalidates any previous unused tokens for the user.
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_password_reset_request(p_request jsonb)
|
||||||
|
RETURNS TABLE(p_success boolean, p_error text, p_data jsonb) AS $$
|
||||||
|
DECLARE
|
||||||
|
v_user_id INTEGER;
|
||||||
|
v_email TEXT;
|
||||||
|
v_username TEXT;
|
||||||
|
v_raw_token TEXT;
|
||||||
|
v_token_hash TEXT;
|
||||||
|
v_expires_at TIMESTAMP;
|
||||||
|
BEGIN
|
||||||
|
v_email := p_request->>'email';
|
||||||
|
v_username := p_request->>'username';
|
||||||
|
|
||||||
|
-- Require at least one identifier
|
||||||
|
IF (v_email IS NULL OR v_email = '') AND (v_username IS NULL OR v_username = '') THEN
|
||||||
|
RETURN QUERY SELECT false, 'email or username is required'::text, NULL::jsonb;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Look up active user
|
||||||
|
IF v_email IS NOT NULL AND v_email <> '' THEN
|
||||||
|
SELECT id INTO v_user_id FROM users WHERE email = v_email AND is_active = true;
|
||||||
|
ELSE
|
||||||
|
SELECT id INTO v_user_id FROM users WHERE username = v_username AND is_active = true;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Return generic success even when user not found to avoid user enumeration
|
||||||
|
IF NOT FOUND THEN
|
||||||
|
RETURN QUERY SELECT true, NULL::text, jsonb_build_object('token', '', 'expires_in', 0);
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Invalidate previous unused tokens for this user
|
||||||
|
DELETE FROM user_password_resets WHERE user_id = v_user_id AND used = false;
|
||||||
|
|
||||||
|
-- Generate a random 32-byte token and store its SHA-256 hash
|
||||||
|
v_raw_token := encode(gen_random_bytes(32), 'hex');
|
||||||
|
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
|
||||||
|
v_expires_at := now() + interval '1 hour';
|
||||||
|
|
||||||
|
INSERT INTO user_password_resets (user_id, token_hash, expires_at)
|
||||||
|
VALUES (v_user_id, v_token_hash, v_expires_at);
|
||||||
|
|
||||||
|
RETURN QUERY SELECT
|
||||||
|
true,
|
||||||
|
NULL::text,
|
||||||
|
jsonb_build_object(
|
||||||
|
'token', v_raw_token,
|
||||||
|
'expires_in', 3600
|
||||||
|
);
|
||||||
|
EXCEPTION
|
||||||
|
WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM::text, NULL::jsonb;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
-- 2. resolvespec_password_reset - Validates the token and updates the user's password
|
||||||
|
-- Input: p_request jsonb {token: string, new_password: string}
|
||||||
|
-- Output: p_success (bool), p_error (text)
|
||||||
|
-- NOTE: Hash the new_password with bcrypt before storing (pgcrypto crypt/gen_salt).
|
||||||
|
-- The TODO below mirrors the convention used in resolvespec_register.
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_password_reset(p_request jsonb)
|
||||||
|
RETURNS TABLE(p_success boolean, p_error text) AS $$
|
||||||
|
DECLARE
|
||||||
|
v_raw_token TEXT;
|
||||||
|
v_token_hash TEXT;
|
||||||
|
v_new_pw TEXT;
|
||||||
|
v_reset_id INTEGER;
|
||||||
|
v_user_id INTEGER;
|
||||||
|
v_expires_at TIMESTAMP;
|
||||||
|
BEGIN
|
||||||
|
v_raw_token := p_request->>'token';
|
||||||
|
v_new_pw := p_request->>'new_password';
|
||||||
|
|
||||||
|
IF v_raw_token IS NULL OR v_raw_token = '' THEN
|
||||||
|
RETURN QUERY SELECT false, 'token is required'::text;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF v_new_pw IS NULL OR v_new_pw = '' THEN
|
||||||
|
RETURN QUERY SELECT false, 'new_password is required'::text;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
v_token_hash := encode(digest(v_raw_token, 'sha256'), 'hex');
|
||||||
|
|
||||||
|
-- Find valid, unused reset token
|
||||||
|
SELECT id, user_id, expires_at
|
||||||
|
INTO v_reset_id, v_user_id, v_expires_at
|
||||||
|
FROM user_password_resets
|
||||||
|
WHERE token_hash = v_token_hash AND used = false;
|
||||||
|
|
||||||
|
IF NOT FOUND THEN
|
||||||
|
RETURN QUERY SELECT false, 'invalid or expired token'::text;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF v_expires_at <= now() THEN
|
||||||
|
RETURN QUERY SELECT false, 'invalid or expired token'::text;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- TODO: Hash new password with pgcrypto before storing
|
||||||
|
-- Enable pgcrypto: CREATE EXTENSION IF NOT EXISTS pgcrypto;
|
||||||
|
-- v_new_pw := crypt(v_new_pw, gen_salt('bf'));
|
||||||
|
|
||||||
|
-- Update password and invalidate all sessions
|
||||||
|
UPDATE users SET password = v_new_pw, updated_at = now() WHERE id = v_user_id;
|
||||||
|
DELETE FROM user_sessions WHERE user_id = v_user_id;
|
||||||
|
|
||||||
|
-- Mark token as used
|
||||||
|
UPDATE user_password_resets SET used = true, used_at = now() WHERE id = v_reset_id;
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, NULL::text;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM::text;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
-- Example: Test password reset stored procedures
|
||||||
|
-- SELECT * FROM resolvespec_password_reset_request('{"email": "user@example.com"}'::jsonb);
|
||||||
|
-- SELECT * FROM resolvespec_password_reset('{"token": "<raw_token>", "new_password": "newpass123"}'::jsonb);
|
||||||
|
|
||||||
|
-- ============================================
|
||||||
|
-- OAuth2 Server Tables (OAuthServer persistence)
|
||||||
|
-- ============================================
|
||||||
|
|
||||||
|
-- oauth_clients: persistent RFC 7591 registered clients
|
||||||
|
CREATE TABLE IF NOT EXISTS oauth_clients (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
client_id VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
redirect_uris TEXT[] NOT NULL,
|
||||||
|
client_name VARCHAR(255),
|
||||||
|
grant_types TEXT[] DEFAULT ARRAY['authorization_code'],
|
||||||
|
allowed_scopes TEXT[] DEFAULT ARRAY['openid','profile','email'],
|
||||||
|
is_active BOOLEAN DEFAULT true,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
-- oauth_codes: short-lived authorization codes (for multi-instance deployments)
|
||||||
|
-- Note: client_id is stored without a foreign key so codes can be persisted even
|
||||||
|
-- when OAuth clients are managed in memory rather than persisted in oauth_clients.
|
||||||
|
CREATE TABLE IF NOT EXISTS oauth_codes (
|
||||||
|
id SERIAL PRIMARY KEY,
|
||||||
|
code VARCHAR(255) NOT NULL UNIQUE,
|
||||||
|
client_id VARCHAR(255) NOT NULL,
|
||||||
|
redirect_uri TEXT NOT NULL,
|
||||||
|
client_state TEXT,
|
||||||
|
code_challenge VARCHAR(255) NOT NULL,
|
||||||
|
code_challenge_method VARCHAR(10) DEFAULT 'S256',
|
||||||
|
session_token TEXT NOT NULL,
|
||||||
|
refresh_token TEXT,
|
||||||
|
scopes TEXT[],
|
||||||
|
expires_at TIMESTAMP NOT NULL,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_oauth_codes_code ON oauth_codes(code);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_oauth_codes_expires ON oauth_codes(expires_at);
|
||||||
|
|
||||||
|
-- ============================================
|
||||||
|
-- OAuth2 Server Stored Procedures
|
||||||
|
-- ============================================
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_register_client(p_data jsonb)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_client_id text;
|
||||||
|
v_row jsonb;
|
||||||
|
BEGIN
|
||||||
|
v_client_id := p_data->>'client_id';
|
||||||
|
|
||||||
|
INSERT INTO oauth_clients (client_id, redirect_uris, client_name, grant_types, allowed_scopes)
|
||||||
|
VALUES (
|
||||||
|
v_client_id,
|
||||||
|
ARRAY(SELECT jsonb_array_elements_text(p_data->'redirect_uris')),
|
||||||
|
p_data->>'client_name',
|
||||||
|
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'grant_types')), ARRAY['authorization_code']),
|
||||||
|
COALESCE(ARRAY(SELECT jsonb_array_elements_text(p_data->'allowed_scopes')), ARRAY['openid','profile','email'])
|
||||||
|
)
|
||||||
|
RETURNING to_jsonb(oauth_clients.*) INTO v_row;
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, null::text, v_row;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM, null::jsonb;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_get_client(p_client_id text)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_row jsonb;
|
||||||
|
BEGIN
|
||||||
|
SELECT to_jsonb(oauth_clients.*)
|
||||||
|
INTO v_row
|
||||||
|
FROM oauth_clients
|
||||||
|
WHERE client_id = p_client_id AND is_active = true;
|
||||||
|
|
||||||
|
IF v_row IS NULL THEN
|
||||||
|
RETURN QUERY SELECT false, 'client not found'::text, null::jsonb;
|
||||||
|
ELSE
|
||||||
|
RETURN QUERY SELECT true, null::text, v_row;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_save_code(p_data jsonb)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
INSERT INTO oauth_codes (code, client_id, redirect_uri, client_state, code_challenge, code_challenge_method, session_token, refresh_token, scopes, expires_at)
|
||||||
|
VALUES (
|
||||||
|
p_data->>'code',
|
||||||
|
p_data->>'client_id',
|
||||||
|
p_data->>'redirect_uri',
|
||||||
|
p_data->>'client_state',
|
||||||
|
p_data->>'code_challenge',
|
||||||
|
COALESCE(p_data->>'code_challenge_method', 'S256'),
|
||||||
|
p_data->>'session_token',
|
||||||
|
p_data->>'refresh_token',
|
||||||
|
ARRAY(SELECT jsonb_array_elements_text(p_data->'scopes')),
|
||||||
|
(p_data->>'expires_at')::timestamp
|
||||||
|
);
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, null::text;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_exchange_code(p_code text)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_row jsonb;
|
||||||
|
BEGIN
|
||||||
|
DELETE FROM oauth_codes
|
||||||
|
WHERE code = p_code AND expires_at > now()
|
||||||
|
RETURNING jsonb_build_object(
|
||||||
|
'client_id', client_id,
|
||||||
|
'redirect_uri', redirect_uri,
|
||||||
|
'client_state', client_state,
|
||||||
|
'code_challenge', code_challenge,
|
||||||
|
'code_challenge_method', code_challenge_method,
|
||||||
|
'session_token', session_token,
|
||||||
|
'refresh_token', refresh_token,
|
||||||
|
'scopes', to_jsonb(scopes)
|
||||||
|
) INTO v_row;
|
||||||
|
|
||||||
|
IF v_row IS NULL THEN
|
||||||
|
RETURN QUERY SELECT false, 'invalid or expired code'::text, null::jsonb;
|
||||||
|
ELSE
|
||||||
|
RETURN QUERY SELECT true, null::text, v_row;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_introspect(p_token text)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text, p_data jsonb)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_row jsonb;
|
||||||
|
BEGIN
|
||||||
|
SELECT jsonb_build_object(
|
||||||
|
'active', true,
|
||||||
|
'sub', u.id::text,
|
||||||
|
'username', u.username,
|
||||||
|
'email', u.email,
|
||||||
|
'user_level', u.user_level,
|
||||||
|
-- NULLIF converts empty string to NULL; string_to_array(NULL) returns NULL;
|
||||||
|
-- to_jsonb(NULL) returns NULL; COALESCE then returns '[]' for NULL/empty roles.
|
||||||
|
'roles', COALESCE(to_jsonb(string_to_array(NULLIF(u.roles, ''), ',')), '[]'::jsonb),
|
||||||
|
'exp', EXTRACT(EPOCH FROM s.expires_at)::bigint,
|
||||||
|
'iat', EXTRACT(EPOCH FROM s.created_at)::bigint
|
||||||
|
)
|
||||||
|
INTO v_row
|
||||||
|
FROM user_sessions s
|
||||||
|
JOIN users u ON u.id = s.user_id
|
||||||
|
WHERE s.session_token = p_token
|
||||||
|
AND s.expires_at > now()
|
||||||
|
AND u.is_active = true;
|
||||||
|
|
||||||
|
IF v_row IS NULL THEN
|
||||||
|
RETURN QUERY SELECT true, null::text, '{"active":false}'::jsonb;
|
||||||
|
ELSE
|
||||||
|
RETURN QUERY SELECT true, null::text, v_row;
|
||||||
|
END IF;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_oauth_revoke(p_token text)
|
||||||
|
RETURNS TABLE(p_success bool, p_error text)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
BEGIN
|
||||||
|
DELETE FROM user_sessions WHERE session_token = p_token;
|
||||||
|
RETURN QUERY SELECT true, null::text;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|||||||
@@ -135,12 +135,6 @@ func (a *JWTAuthenticatorExample) Login(ctx context.Context, req LoginRequest) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
func (a *JWTAuthenticatorExample) Logout(ctx context.Context, req LogoutRequest) error {
|
||||||
// For JWT, logout could involve token blacklisting
|
|
||||||
// Add token to blacklist table
|
|
||||||
// err := a.db.WithContext(ctx).Table("token_blacklist").Create(map[string]interface{}{
|
|
||||||
// "token": req.Token,
|
|
||||||
// "expires_at": time.Now().Add(24 * time.Hour),
|
|
||||||
// }).Error
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
"github.com/bitechdev/ResolveSpec/pkg/logger"
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/modelregistry"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
// SecurityContext is a generic interface that any spec can implement to integrate with security features
|
||||||
@@ -226,6 +227,122 @@ func ApplyColumnSecurity(secCtx SecurityContext, securityList *SecurityList) err
|
|||||||
return applyColumnSecurity(secCtx, securityList)
|
return applyColumnSecurity(secCtx, securityList)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkModelUpdateAllowed returns an error if CanUpdate is false for the model.
|
||||||
|
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||||
|
func checkModelUpdateAllowed(secCtx SecurityContext) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil // model not registered, allow by default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !rules.CanUpdate {
|
||||||
|
return fmt.Errorf("update not allowed for %s", secCtx.GetEntity())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkModelDeleteAllowed returns an error if CanDelete is false for the model.
|
||||||
|
// Rules are read from context (set by NewModelAuthMiddleware) with a fallback to the model registry.
|
||||||
|
func checkModelDeleteAllowed(secCtx SecurityContext) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil // model not registered, allow by default
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !rules.CanDelete {
|
||||||
|
return fmt.Errorf("delete not allowed for %s", secCtx.GetEntity())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelAuthAllowed checks whether the requested operation is permitted based on
|
||||||
|
// model rules and the current user's authentication state. It is intended for use in
|
||||||
|
// a BeforeHandle hook, fired after model resolution.
|
||||||
|
//
|
||||||
|
// Logic:
|
||||||
|
// 1. Load model rules from context (set by NewModelAuthMiddleware) or fall back to registry.
|
||||||
|
// 2. SecurityDisabled → allow.
|
||||||
|
// 3. operation == "read" && CanPublicRead → allow.
|
||||||
|
// 4. operation == "create" && CanPublicCreate → allow.
|
||||||
|
// 5. operation == "update" && CanPublicUpdate → allow.
|
||||||
|
// 6. operation == "delete" && CanPublicDelete → allow.
|
||||||
|
// 7. Guest (UserID == 0) → return "authentication required".
|
||||||
|
// 8. Authenticated user → allow (operation-specific checks remain in BeforeUpdate/BeforeDelete).
|
||||||
|
func CheckModelAuthAllowed(secCtx SecurityContext, operation string) error {
|
||||||
|
rules, ok := GetModelRulesFromContext(secCtx.GetContext())
|
||||||
|
if !ok {
|
||||||
|
schema := secCtx.GetSchema()
|
||||||
|
entity := secCtx.GetEntity()
|
||||||
|
var err error
|
||||||
|
if schema != "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(fmt.Sprintf("%s.%s", schema, entity))
|
||||||
|
}
|
||||||
|
if err != nil || schema == "" {
|
||||||
|
rules, err = modelregistry.GetModelRulesByName(entity)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// Model not registered - fall through to auth check
|
||||||
|
userID, _ := secCtx.GetUserID()
|
||||||
|
if userID == 0 {
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules.SecurityDisabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "read" && rules.CanPublicRead {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "create" && rules.CanPublicCreate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "update" && rules.CanPublicUpdate {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if operation == "delete" && rules.CanPublicDelete {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
userID, _ := secCtx.GetUserID()
|
||||||
|
if userID == 0 {
|
||||||
|
return fmt.Errorf("authentication required")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelUpdateAllowed is the public wrapper for checkModelUpdateAllowed.
|
||||||
|
func CheckModelUpdateAllowed(secCtx SecurityContext) error {
|
||||||
|
return checkModelUpdateAllowed(secCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckModelDeleteAllowed is the public wrapper for checkModelDeleteAllowed.
|
||||||
|
func CheckModelDeleteAllowed(secCtx SecurityContext) error {
|
||||||
|
return checkModelDeleteAllowed(secCtx)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions
|
// Helper functions
|
||||||
|
|
||||||
func contains(s, substr string) bool {
|
func contains(s, substr string) bool {
|
||||||
|
|||||||
@@ -57,17 +57,52 @@ type LogoutRequest struct {
|
|||||||
UserID int `json:"user_id"`
|
UserID int `json:"user_id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswordResetRequest initiates a password reset for a user
|
||||||
|
type PasswordResetRequest struct {
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Username string `json:"username,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordResetResponse is returned when a reset is initiated
|
||||||
|
type PasswordResetResponse struct {
|
||||||
|
// Token is the reset token to be delivered out-of-band (e.g. email).
|
||||||
|
// The stored procedure may return it for delivery or leave it empty
|
||||||
|
// if the delivery is handled entirely in the database.
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"` // seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordResetCompleteRequest completes a password reset using the token
|
||||||
|
type PasswordResetCompleteRequest struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
NewPassword string `json:"new_password"`
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticator handles user authentication operations
|
// Authenticator handles user authentication operations
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
// Login authenticates credentials and returns a token
|
// Login authenticates credentials and returns a token
|
||||||
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
|
Login(ctx context.Context, req LoginRequest) (*LoginResponse, error)
|
||||||
|
|
||||||
|
// LoginWithCookie authenticates credentials and, when cookie sessions are enabled,
|
||||||
|
// writes the session cookie to w. Implementations that do not support cookies
|
||||||
|
// should delegate to Login and ignore w.
|
||||||
|
LoginWithCookie(ctx context.Context, req LoginRequest, w http.ResponseWriter) (*LoginResponse, error)
|
||||||
|
|
||||||
// Logout invalidates a user's session/token
|
// Logout invalidates a user's session/token
|
||||||
Logout(ctx context.Context, req LogoutRequest) error
|
Logout(ctx context.Context, req LogoutRequest) error
|
||||||
|
|
||||||
|
// LogoutWithCookie invalidates a user's session/token and, when cookie sessions are
|
||||||
|
// enabled, clears the session cookie on w. Implementations that do not support cookies
|
||||||
|
// should delegate to Logout and ignore w.
|
||||||
|
LogoutWithCookie(ctx context.Context, req LogoutRequest, w http.ResponseWriter) error
|
||||||
|
|
||||||
// Authenticate extracts and validates user from HTTP request
|
// Authenticate extracts and validates user from HTTP request
|
||||||
// Returns UserContext or error if authentication fails
|
// Returns UserContext or error if authentication fails
|
||||||
Authenticate(r *http.Request) (*UserContext, error)
|
Authenticate(r *http.Request) (*UserContext, error)
|
||||||
|
|
||||||
|
// SetAuthenticateCallback registers a fallback called when primary authentication fails.
|
||||||
|
// If the callback returns a non-nil UserContext, that result is used instead of the error.
|
||||||
|
SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Registrable allows providers to support user registration
|
// Registrable allows providers to support user registration
|
||||||
@@ -114,3 +149,12 @@ type Cacheable interface {
|
|||||||
// ClearCache clears cached security rules for a user/entity
|
// ClearCache clears cached security rules for a user/entity
|
||||||
ClearCache(ctx context.Context, userID int, schema, table string) error
|
ClearCache(ctx context.Context, userID int, schema, table string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PasswordResettable allows providers to support self-service password reset
|
||||||
|
type PasswordResettable interface {
|
||||||
|
// RequestPasswordReset creates a reset token for the given email/username
|
||||||
|
RequestPasswordReset(ctx context.Context, req PasswordResetRequest) (*PasswordResetResponse, error)
|
||||||
|
|
||||||
|
// CompletePasswordReset validates the token and sets the new password
|
||||||
|
CompletePasswordReset(ctx context.Context, req PasswordResetCompleteRequest) error
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// hashSHA256Hex returns the lowercase hex SHA-256 digest of the given string.
|
||||||
|
// Used by all keystore implementations to hash raw keys before storage or lookup.
|
||||||
|
func hashSHA256Hex(raw string) string {
|
||||||
|
sum := sha256.Sum256([]byte(raw))
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyType identifies the category of an auth key.
|
||||||
|
type KeyType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KeyTypeJWTSecret is a per-user JWT signing secret for token generation.
|
||||||
|
KeyTypeJWTSecret KeyType = "jwt_secret"
|
||||||
|
// KeyTypeHeaderAPI is a static API key sent via a request header.
|
||||||
|
KeyTypeHeaderAPI KeyType = "header_api"
|
||||||
|
// KeyTypeOAuth2 holds OAuth2 client credentials (client_id / client_secret).
|
||||||
|
KeyTypeOAuth2 KeyType = "oauth2"
|
||||||
|
// KeyTypeGenericAPI is a generic application API key.
|
||||||
|
KeyTypeGenericAPI KeyType = "api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UserKey represents a single named auth key belonging to a user.
|
||||||
|
// KeyHash stores the SHA-256 hex digest of the raw key; the raw key is never persisted.
|
||||||
|
type UserKey struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
KeyType KeyType `json:"key_type"`
|
||||||
|
KeyHash string `json:"key_hash"` // SHA-256 hex; never the raw key
|
||||||
|
Name string `json:"name"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
Meta map[string]any `json:"meta,omitempty"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
LastUsedAt *time.Time `json:"last_used_at,omitempty"`
|
||||||
|
IsActive bool `json:"is_active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKeyRequest specifies the parameters for a new key.
|
||||||
|
type CreateKeyRequest struct {
|
||||||
|
UserID int
|
||||||
|
KeyType KeyType
|
||||||
|
Name string
|
||||||
|
Scopes []string
|
||||||
|
Meta map[string]any
|
||||||
|
ExpiresAt *time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKeyResponse is returned exactly once when a key is created.
|
||||||
|
// The caller is responsible for persisting RawKey; it is not stored anywhere.
|
||||||
|
type CreateKeyResponse struct {
|
||||||
|
Key UserKey
|
||||||
|
RawKey string // crypto/rand 32 bytes, base64url-encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyStore manages per-user auth keys with pluggable storage backends.
|
||||||
|
// Implementations: ConfigKeyStore (static list) and DatabaseKeyStore (stored procedures).
|
||||||
|
type KeyStore interface {
|
||||||
|
// CreateKey generates a new key, stores its hash, and returns the raw key once.
|
||||||
|
CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error)
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for a user.
|
||||||
|
// Pass an empty KeyType to return all types.
|
||||||
|
GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error)
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key by ID after verifying ownership.
|
||||||
|
DeleteKey(ctx context.Context, userID int, keyID int64) error
|
||||||
|
|
||||||
|
// ValidateKey checks a raw key, returns the matching UserKey on success.
|
||||||
|
// The implementation hashes the raw key before any lookup.
|
||||||
|
// Pass an empty KeyType to accept any type.
|
||||||
|
ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error)
|
||||||
|
}
|
||||||
@@ -0,0 +1,119 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// KeyStoreAuthenticator implements the Authenticator interface using a KeyStore.
|
||||||
|
// It is suitable for long-lived application credentials (API keys, JWT secrets, etc.)
|
||||||
|
// rather than interactive sessions. Login and Logout are not supported — key lifecycle
|
||||||
|
// is managed directly through the KeyStore.
|
||||||
|
//
|
||||||
|
// Key extraction order:
|
||||||
|
// 1. Authorization: Bearer <key>
|
||||||
|
// 2. Authorization: ApiKey <key>
|
||||||
|
// 3. X-API-Key header
|
||||||
|
type KeyStoreAuthenticator struct {
|
||||||
|
keyStore KeyStore
|
||||||
|
keyType KeyType // empty = accept any type
|
||||||
|
authenticateCallback func(r *http.Request) (*UserContext, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewKeyStoreAuthenticator creates a KeyStoreAuthenticator.
|
||||||
|
// Pass an empty keyType to accept keys of any type.
|
||||||
|
func NewKeyStoreAuthenticator(ks KeyStore, keyType KeyType) *KeyStoreAuthenticator {
|
||||||
|
return &KeyStoreAuthenticator{keyStore: ks, keyType: keyType}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Login is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) Login(_ context.Context, _ LoginRequest) (*LoginResponse, error) {
|
||||||
|
return nil, fmt.Errorf("keystore authenticator does not support login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginWithCookie is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) LoginWithCookie(_ context.Context, _ LoginRequest, _ http.ResponseWriter) (*LoginResponse, error) {
|
||||||
|
return nil, fmt.Errorf("keystore authenticator does not support login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logout is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) Logout(_ context.Context, _ LogoutRequest) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogoutWithCookie is not supported for keystore authentication.
|
||||||
|
func (a *KeyStoreAuthenticator) LogoutWithCookie(_ context.Context, _ LogoutRequest, _ http.ResponseWriter) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAuthenticateCallback registers a fallback called when key authentication fails.
|
||||||
|
func (a *KeyStoreAuthenticator) SetAuthenticateCallback(fn func(r *http.Request) (*UserContext, error)) {
|
||||||
|
a.authenticateCallback = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authenticate extracts an API key from the request and validates it against the KeyStore.
|
||||||
|
// Returns a UserContext built from the matching UserKey on success.
|
||||||
|
func (a *KeyStoreAuthenticator) Authenticate(r *http.Request) (*UserContext, error) {
|
||||||
|
rawKey := extractAPIKey(r)
|
||||||
|
if rawKey == "" {
|
||||||
|
if a.authenticateCallback != nil {
|
||||||
|
return a.authenticateCallback(r)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("API key required (Authorization: Bearer/ApiKey <key> or X-API-Key header)")
|
||||||
|
}
|
||||||
|
|
||||||
|
userKey, err := a.keyStore.ValidateKey(r.Context(), rawKey, a.keyType)
|
||||||
|
if err != nil {
|
||||||
|
if a.authenticateCallback != nil {
|
||||||
|
return a.authenticateCallback(r)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid API key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userKeyToUserContext(userKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractAPIKey extracts a raw key from the request using the following precedence:
|
||||||
|
// 1. Authorization: Bearer <key>
|
||||||
|
// 2. Authorization: ApiKey <key>
|
||||||
|
// 3. X-API-Key header
|
||||||
|
func extractAPIKey(r *http.Request) string {
|
||||||
|
if auth := r.Header.Get("Authorization"); auth != "" {
|
||||||
|
if after, ok := strings.CutPrefix(auth, "Bearer "); ok {
|
||||||
|
return strings.TrimSpace(after)
|
||||||
|
}
|
||||||
|
if after, ok := strings.CutPrefix(auth, "ApiKey "); ok {
|
||||||
|
return strings.TrimSpace(after)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(r.Header.Get("X-API-Key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// userKeyToUserContext converts a UserKey into a UserContext.
|
||||||
|
// Scopes are mapped to Roles. Key type and name are stored in Claims.
|
||||||
|
func userKeyToUserContext(k *UserKey) *UserContext {
|
||||||
|
claims := map[string]any{
|
||||||
|
"key_type": string(k.KeyType),
|
||||||
|
"key_name": k.Name,
|
||||||
|
}
|
||||||
|
|
||||||
|
meta := k.Meta
|
||||||
|
if meta == nil {
|
||||||
|
meta = map[string]any{}
|
||||||
|
}
|
||||||
|
|
||||||
|
roles := k.Scopes
|
||||||
|
if roles == nil {
|
||||||
|
roles = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &UserContext{
|
||||||
|
UserID: k.UserID,
|
||||||
|
SessionID: fmt.Sprintf("key:%d", k.ID),
|
||||||
|
Roles: roles,
|
||||||
|
Claims: claims,
|
||||||
|
Meta: meta,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigKeyStore is an in-memory keystore backed by a static slice of UserKey values.
|
||||||
|
// It is designed for config-file driven setups (e.g. service accounts defined in YAML)
|
||||||
|
// with a small, bounded number of keys. For large or dynamic key sets use DatabaseKeyStore.
|
||||||
|
//
|
||||||
|
// Pre-existing entries must have KeyHash set to the SHA-256 hex of the intended raw key.
|
||||||
|
// Keys created at runtime via CreateKey are held in memory only and lost on restart.
|
||||||
|
type ConfigKeyStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
keys []UserKey
|
||||||
|
next int64 // monotonic ID counter for runtime-created keys (atomic)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConfigKeyStore creates a ConfigKeyStore seeded with the provided keys.
|
||||||
|
// Pass nil or an empty slice to start with no pre-loaded keys.
|
||||||
|
// Zero-value entries (CreatedAt is zero) are treated as active and assigned the current time.
|
||||||
|
func NewConfigKeyStore(keys []UserKey) *ConfigKeyStore {
|
||||||
|
var maxID int64
|
||||||
|
copied := make([]UserKey, len(keys))
|
||||||
|
copy(copied, keys)
|
||||||
|
for i := range copied {
|
||||||
|
if copied[i].CreatedAt.IsZero() {
|
||||||
|
copied[i].IsActive = true
|
||||||
|
copied[i].CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if copied[i].ID > maxID {
|
||||||
|
maxID = copied[i].ID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &ConfigKeyStore{keys: copied, next: maxID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKey generates a new raw key, stores its SHA-256 hash, and returns the raw key once.
|
||||||
|
func (s *ConfigKeyStore) CreateKey(_ context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||||
|
rawBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(rawBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||||
|
}
|
||||||
|
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||||
|
hash := hashSHA256Hex(rawKey)
|
||||||
|
|
||||||
|
id := atomic.AddInt64(&s.next, 1)
|
||||||
|
key := UserKey{
|
||||||
|
ID: id,
|
||||||
|
UserID: req.UserID,
|
||||||
|
KeyType: req.KeyType,
|
||||||
|
KeyHash: hash,
|
||||||
|
Name: req.Name,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
Meta: req.Meta,
|
||||||
|
ExpiresAt: req.ExpiresAt,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
IsActive: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.keys = append(s.keys, key)
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||||
|
// Pass an empty KeyType to return all types.
|
||||||
|
func (s *ConfigKeyStore) GetUserKeys(_ context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||||
|
now := time.Now()
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
var result []UserKey
|
||||||
|
for i := range s.keys {
|
||||||
|
k := &s.keys[i]
|
||||||
|
if k.UserID != userID || !k.IsActive {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if keyType != "" && k.KeyType != keyType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, *k)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key by setting IsActive to false after ownership verification.
|
||||||
|
func (s *ConfigKeyStore) DeleteKey(_ context.Context, userID int, keyID int64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range s.keys {
|
||||||
|
if s.keys[i].ID == keyID {
|
||||||
|
if s.keys[i].UserID != userID {
|
||||||
|
return fmt.Errorf("key not found or permission denied")
|
||||||
|
}
|
||||||
|
s.keys[i].IsActive = false
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("key not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateKey hashes the raw key and finds a matching, active, non-expired entry.
|
||||||
|
// Uses constant-time comparison to prevent timing side-channels.
|
||||||
|
// Pass an empty KeyType to accept any type.
|
||||||
|
func (s *ConfigKeyStore) ValidateKey(_ context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||||
|
hash := hashSHA256Hex(rawKey)
|
||||||
|
hashBytes, _ := hex.DecodeString(hash)
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Write lock: ValidateKey updates LastUsedAt on the matched entry.
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for i := range s.keys {
|
||||||
|
k := &s.keys[i]
|
||||||
|
if !k.IsActive {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if k.ExpiresAt != nil && k.ExpiresAt.Before(now) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if keyType != "" && k.KeyType != keyType {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stored, _ := hex.DecodeString(k.KeyHash)
|
||||||
|
if subtle.ConstantTimeCompare(hashBytes, stored) != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
k.LastUsedAt = &now
|
||||||
|
result := *k
|
||||||
|
return &result, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("invalid or expired key")
|
||||||
|
}
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/bitechdev/ResolveSpec/pkg/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DatabaseKeyStoreOptions configures DatabaseKeyStore.
|
||||||
|
type DatabaseKeyStoreOptions struct {
|
||||||
|
// Cache is an optional cache instance. If nil, uses the default cache.
|
||||||
|
Cache *cache.Cache
|
||||||
|
// CacheTTL is the duration to cache ValidateKey results.
|
||||||
|
// Default: 2 minutes.
|
||||||
|
CacheTTL time.Duration
|
||||||
|
// SQLNames provides custom procedure names. If nil, uses DefaultKeyStoreSQLNames().
|
||||||
|
SQLNames *KeyStoreSQLNames
|
||||||
|
// DBFactory is called to obtain a fresh *sql.DB when the existing connection is closed.
|
||||||
|
// If nil, reconnection is disabled.
|
||||||
|
DBFactory func() (*sql.DB, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseKeyStore is a KeyStore backed by PostgreSQL stored procedures.
|
||||||
|
// All DB operations go through configurable procedure names; the raw key is
|
||||||
|
// never passed to the database.
|
||||||
|
//
|
||||||
|
// See keystore_schema.sql for the required table and procedure definitions.
|
||||||
|
//
|
||||||
|
// Note: DeleteKey invalidates the cache entry for the deleted key. Due to the
|
||||||
|
// cache TTL, a deleted key may continue to authenticate for up to CacheTTL
|
||||||
|
// (default 2 minutes) if the cache entry cannot be invalidated.
|
||||||
|
type DatabaseKeyStore struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbMu sync.RWMutex
|
||||||
|
dbFactory func() (*sql.DB, error)
|
||||||
|
sqlNames *KeyStoreSQLNames
|
||||||
|
cache *cache.Cache
|
||||||
|
cacheTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDatabaseKeyStore creates a DatabaseKeyStore with optional configuration.
|
||||||
|
func NewDatabaseKeyStore(db *sql.DB, opts ...DatabaseKeyStoreOptions) *DatabaseKeyStore {
|
||||||
|
o := DatabaseKeyStoreOptions{}
|
||||||
|
if len(opts) > 0 {
|
||||||
|
o = opts[0]
|
||||||
|
}
|
||||||
|
if o.CacheTTL == 0 {
|
||||||
|
o.CacheTTL = 2 * time.Minute
|
||||||
|
}
|
||||||
|
c := o.Cache
|
||||||
|
if c == nil {
|
||||||
|
c = cache.GetDefaultCache()
|
||||||
|
}
|
||||||
|
names := MergeKeyStoreSQLNames(DefaultKeyStoreSQLNames(), o.SQLNames)
|
||||||
|
return &DatabaseKeyStore{
|
||||||
|
db: db,
|
||||||
|
dbFactory: o.DBFactory,
|
||||||
|
sqlNames: names,
|
||||||
|
cache: c,
|
||||||
|
cacheTTL: o.CacheTTL,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *DatabaseKeyStore) getDB() *sql.DB {
|
||||||
|
ks.dbMu.RLock()
|
||||||
|
defer ks.dbMu.RUnlock()
|
||||||
|
return ks.db
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ks *DatabaseKeyStore) reconnectDB() error {
|
||||||
|
if ks.dbFactory == nil {
|
||||||
|
return fmt.Errorf("no db factory configured for reconnect")
|
||||||
|
}
|
||||||
|
newDB, err := ks.dbFactory()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ks.dbMu.Lock()
|
||||||
|
ks.db = newDB
|
||||||
|
ks.dbMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateKey generates a raw key, stores its SHA-256 hash via the create procedure,
|
||||||
|
// and returns the raw key once.
|
||||||
|
func (ks *DatabaseKeyStore) CreateKey(ctx context.Context, req CreateKeyRequest) (*CreateKeyResponse, error) {
|
||||||
|
rawBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(rawBytes); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate key material: %w", err)
|
||||||
|
}
|
||||||
|
rawKey := base64.RawURLEncoding.EncodeToString(rawBytes)
|
||||||
|
hash := hashSHA256Hex(rawKey)
|
||||||
|
|
||||||
|
type createRequest struct {
|
||||||
|
UserID int `json:"user_id"`
|
||||||
|
KeyType KeyType `json:"key_type"`
|
||||||
|
KeyHash string `json:"key_hash"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Scopes []string `json:"scopes,omitempty"`
|
||||||
|
Meta map[string]any `json:"meta,omitempty"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
reqJSON, err := json.Marshal(createRequest{
|
||||||
|
UserID: req.UserID,
|
||||||
|
KeyType: req.KeyType,
|
||||||
|
KeyHash: hash,
|
||||||
|
Name: req.Name,
|
||||||
|
Scopes: req.Scopes,
|
||||||
|
Meta: req.Meta,
|
||||||
|
ExpiresAt: req.ExpiresAt,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal create key request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var keyJSON sql.NullString
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1::jsonb)`, ks.sqlNames.CreateKey)
|
||||||
|
if err = ks.getDB().QueryRowContext(ctx, query, string(reqJSON)).Scan(&success, &errorMsg, &keyJSON); err != nil {
|
||||||
|
return nil, fmt.Errorf("create key procedure failed: %w", err)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
return nil, errors.New(nullStringOr(errorMsg, "create key failed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var key UserKey
|
||||||
|
if err = json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse created key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CreateKeyResponse{Key: key, RawKey: rawKey}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserKeys returns all active, non-expired keys for the given user.
|
||||||
|
// Pass an empty KeyType to return all types.
|
||||||
|
func (ks *DatabaseKeyStore) GetUserKeys(ctx context.Context, userID int, keyType KeyType) ([]UserKey, error) {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var keysJSON sql.NullString
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_keys::text FROM %s($1, $2)`, ks.sqlNames.GetUserKeys)
|
||||||
|
if err := ks.getDB().QueryRowContext(ctx, query, userID, string(keyType)).Scan(&success, &errorMsg, &keysJSON); err != nil {
|
||||||
|
return nil, fmt.Errorf("get user keys procedure failed: %w", err)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
return nil, errors.New(nullStringOr(errorMsg, "get user keys failed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []UserKey
|
||||||
|
if keysJSON.Valid && keysJSON.String != "" && keysJSON.String != "[]" {
|
||||||
|
if err := json.Unmarshal([]byte(keysJSON.String), &keys); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse user keys: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if keys == nil {
|
||||||
|
keys = []UserKey{}
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteKey soft-deletes a key after verifying ownership and invalidates its cache entry.
|
||||||
|
// The delete procedure returns the key_hash so no separate lookup is needed.
|
||||||
|
// Note: cache invalidation is best-effort; a cached entry may persist for up to CacheTTL.
|
||||||
|
func (ks *DatabaseKeyStore) DeleteKey(ctx context.Context, userID int, keyID int64) error {
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var keyHash sql.NullString
|
||||||
|
|
||||||
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key_hash FROM %s($1, $2)`, ks.sqlNames.DeleteKey)
|
||||||
|
if err := ks.getDB().QueryRowContext(ctx, query, userID, keyID).Scan(&success, &errorMsg, &keyHash); err != nil {
|
||||||
|
return fmt.Errorf("delete key procedure failed: %w", err)
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
return errors.New(nullStringOr(errorMsg, "delete key failed"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if keyHash.Valid && keyHash.String != "" && ks.cache != nil {
|
||||||
|
_ = ks.cache.Delete(ctx, keystoreCacheKey(keyHash.String))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateKey hashes the raw key and calls the validate procedure.
|
||||||
|
// Results are cached for CacheTTL to reduce DB load on hot paths.
|
||||||
|
func (ks *DatabaseKeyStore) ValidateKey(ctx context.Context, rawKey string, keyType KeyType) (*UserKey, error) {
|
||||||
|
hash := hashSHA256Hex(rawKey)
|
||||||
|
cacheKey := keystoreCacheKey(hash)
|
||||||
|
|
||||||
|
if ks.cache != nil {
|
||||||
|
var cached UserKey
|
||||||
|
if err := ks.cache.Get(ctx, cacheKey, &cached); err == nil {
|
||||||
|
if cached.IsActive {
|
||||||
|
return &cached, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("invalid or expired key")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var success bool
|
||||||
|
var errorMsg sql.NullString
|
||||||
|
var keyJSON sql.NullString
|
||||||
|
|
||||||
|
runQuery := func() error {
|
||||||
|
query := fmt.Sprintf(`SELECT p_success, p_error, p_key::text FROM %s($1, $2)`, ks.sqlNames.ValidateKey)
|
||||||
|
return ks.getDB().QueryRowContext(ctx, query, hash, string(keyType)).Scan(&success, &errorMsg, &keyJSON)
|
||||||
|
}
|
||||||
|
if err := runQuery(); err != nil {
|
||||||
|
if isDBClosed(err) {
|
||||||
|
if reconnErr := ks.reconnectDB(); reconnErr == nil {
|
||||||
|
err = runQuery()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("validate key procedure failed: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !success {
|
||||||
|
return nil, errors.New(nullStringOr(errorMsg, "invalid or expired key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
var key UserKey
|
||||||
|
if err := json.Unmarshal([]byte(keyJSON.String), &key); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse validated key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ks.cache != nil {
|
||||||
|
_ = ks.cache.Set(ctx, cacheKey, key, ks.cacheTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func keystoreCacheKey(hash string) string {
|
||||||
|
return "keystore:validate:" + hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullStringOr returns s.String if valid, otherwise the fallback.
|
||||||
|
func nullStringOr(s sql.NullString, fallback string) string {
|
||||||
|
if s.Valid && s.String != "" {
|
||||||
|
return s.String
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
@@ -0,0 +1,187 @@
|
|||||||
|
-- Keystore schema for per-user auth keys
|
||||||
|
-- Apply alongside database_schema.sql (requires the users table)
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS user_keys (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
key_type VARCHAR(50) NOT NULL,
|
||||||
|
key_hash VARCHAR(64) NOT NULL UNIQUE, -- SHA-256 hex digest (64 chars)
|
||||||
|
name VARCHAR(255) NOT NULL DEFAULT '',
|
||||||
|
scopes TEXT, -- JSON array, e.g. '["read","write"]'
|
||||||
|
meta JSONB,
|
||||||
|
expires_at TIMESTAMP,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_used_at TIMESTAMP,
|
||||||
|
is_active BOOLEAN DEFAULT true
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_keys_user_id ON user_keys(user_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_keys_key_hash ON user_keys(key_hash);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_user_keys_key_type ON user_keys(key_type);
|
||||||
|
|
||||||
|
-- resolvespec_keystore_get_user_keys
|
||||||
|
-- Returns all active, non-expired keys for a user.
|
||||||
|
-- Pass empty p_key_type to return all key types.
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_keystore_get_user_keys(
|
||||||
|
p_user_id INTEGER,
|
||||||
|
p_key_type TEXT DEFAULT ''
|
||||||
|
)
|
||||||
|
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_keys JSONB)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_keys JSONB;
|
||||||
|
BEGIN
|
||||||
|
SELECT COALESCE(
|
||||||
|
jsonb_agg(
|
||||||
|
jsonb_build_object(
|
||||||
|
'id', k.id,
|
||||||
|
'user_id', k.user_id,
|
||||||
|
'key_type', k.key_type,
|
||||||
|
'name', k.name,
|
||||||
|
'scopes', CASE WHEN k.scopes IS NOT NULL THEN k.scopes::jsonb ELSE '[]'::jsonb END,
|
||||||
|
'meta', COALESCE(k.meta, '{}'::jsonb),
|
||||||
|
'expires_at', k.expires_at,
|
||||||
|
'created_at', k.created_at,
|
||||||
|
'last_used_at', k.last_used_at,
|
||||||
|
'is_active', k.is_active
|
||||||
|
)
|
||||||
|
),
|
||||||
|
'[]'::jsonb
|
||||||
|
)
|
||||||
|
INTO v_keys
|
||||||
|
FROM user_keys k
|
||||||
|
WHERE k.user_id = p_user_id
|
||||||
|
AND k.is_active = true
|
||||||
|
AND (k.expires_at IS NULL OR k.expires_at > NOW())
|
||||||
|
AND (p_key_type = '' OR k.key_type = p_key_type);
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, NULL::TEXT, v_keys;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_create_key
|
||||||
|
-- Inserts a new key row. key_hash is provided by the caller (Go hashes the raw key).
|
||||||
|
-- Returns the created key record (without key_hash).
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_keystore_create_key(
|
||||||
|
p_request JSONB
|
||||||
|
)
|
||||||
|
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_id BIGINT;
|
||||||
|
v_created_at TIMESTAMP;
|
||||||
|
v_key JSONB;
|
||||||
|
BEGIN
|
||||||
|
INSERT INTO user_keys (user_id, key_type, key_hash, name, scopes, meta, expires_at)
|
||||||
|
VALUES (
|
||||||
|
(p_request->>'user_id')::INTEGER,
|
||||||
|
p_request->>'key_type',
|
||||||
|
p_request->>'key_hash',
|
||||||
|
COALESCE(p_request->>'name', ''),
|
||||||
|
p_request->>'scopes',
|
||||||
|
p_request->'meta',
|
||||||
|
CASE WHEN p_request->>'expires_at' IS NOT NULL
|
||||||
|
THEN (p_request->>'expires_at')::TIMESTAMP
|
||||||
|
ELSE NULL
|
||||||
|
END
|
||||||
|
)
|
||||||
|
RETURNING id, created_at INTO v_id, v_created_at;
|
||||||
|
|
||||||
|
v_key := jsonb_build_object(
|
||||||
|
'id', v_id,
|
||||||
|
'user_id', (p_request->>'user_id')::INTEGER,
|
||||||
|
'key_type', p_request->>'key_type',
|
||||||
|
'name', COALESCE(p_request->>'name', ''),
|
||||||
|
'scopes', CASE WHEN p_request->>'scopes' IS NOT NULL
|
||||||
|
THEN (p_request->>'scopes')::jsonb
|
||||||
|
ELSE '[]'::jsonb END,
|
||||||
|
'meta', COALESCE(p_request->'meta', '{}'::jsonb),
|
||||||
|
'expires_at', p_request->>'expires_at',
|
||||||
|
'created_at', v_created_at,
|
||||||
|
'is_active', true
|
||||||
|
);
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_delete_key
|
||||||
|
-- Soft-deletes a key (is_active = false) after verifying ownership.
|
||||||
|
-- Returns p_key_hash so the caller can invalidate cache entries without a separate query.
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_keystore_delete_key(
|
||||||
|
p_user_id INTEGER,
|
||||||
|
p_key_id BIGINT
|
||||||
|
)
|
||||||
|
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key_hash TEXT)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_hash TEXT;
|
||||||
|
BEGIN
|
||||||
|
UPDATE user_keys
|
||||||
|
SET is_active = false
|
||||||
|
WHERE id = p_key_id AND user_id = p_user_id AND is_active = true
|
||||||
|
RETURNING key_hash INTO v_hash;
|
||||||
|
|
||||||
|
IF NOT FOUND THEN
|
||||||
|
RETURN QUERY SELECT false, 'key not found or already deleted'::TEXT, NULL::TEXT;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, NULL::TEXT, v_hash;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM, NULL::TEXT;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
-- resolvespec_keystore_validate_key
|
||||||
|
-- Looks up a key by its SHA-256 hash, checks active status and expiry,
|
||||||
|
-- updates last_used_at, and returns the key record.
|
||||||
|
-- p_key_type can be empty to accept any key type.
|
||||||
|
CREATE OR REPLACE FUNCTION resolvespec_keystore_validate_key(
|
||||||
|
p_key_hash TEXT,
|
||||||
|
p_key_type TEXT DEFAULT ''
|
||||||
|
)
|
||||||
|
RETURNS TABLE(p_success BOOLEAN, p_error TEXT, p_key JSONB)
|
||||||
|
LANGUAGE plpgsql AS $$
|
||||||
|
DECLARE
|
||||||
|
v_key_rec user_keys%ROWTYPE;
|
||||||
|
v_key JSONB;
|
||||||
|
BEGIN
|
||||||
|
SELECT * INTO v_key_rec
|
||||||
|
FROM user_keys
|
||||||
|
WHERE key_hash = p_key_hash
|
||||||
|
AND is_active = true
|
||||||
|
AND (expires_at IS NULL OR expires_at > NOW())
|
||||||
|
AND (p_key_type = '' OR key_type = p_key_type);
|
||||||
|
|
||||||
|
IF NOT FOUND THEN
|
||||||
|
RETURN QUERY SELECT false, 'invalid or expired key'::TEXT, NULL::JSONB;
|
||||||
|
RETURN;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
UPDATE user_keys SET last_used_at = NOW() WHERE id = v_key_rec.id;
|
||||||
|
|
||||||
|
v_key := jsonb_build_object(
|
||||||
|
'id', v_key_rec.id,
|
||||||
|
'user_id', v_key_rec.user_id,
|
||||||
|
'key_type', v_key_rec.key_type,
|
||||||
|
'name', v_key_rec.name,
|
||||||
|
'scopes', CASE WHEN v_key_rec.scopes IS NOT NULL
|
||||||
|
THEN v_key_rec.scopes::jsonb
|
||||||
|
ELSE '[]'::jsonb END,
|
||||||
|
'meta', COALESCE(v_key_rec.meta, '{}'::jsonb),
|
||||||
|
'expires_at', v_key_rec.expires_at,
|
||||||
|
'created_at', v_key_rec.created_at,
|
||||||
|
'last_used_at', NOW(),
|
||||||
|
'is_active', v_key_rec.is_active
|
||||||
|
);
|
||||||
|
|
||||||
|
RETURN QUERY SELECT true, NULL::TEXT, v_key;
|
||||||
|
EXCEPTION WHEN OTHERS THEN
|
||||||
|
RETURN QUERY SELECT false, SQLERRM, NULL::JSONB;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
package security
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// KeyStoreSQLNames holds the configurable stored procedure names used by DatabaseKeyStore.
|
||||||
|
// Use DefaultKeyStoreSQLNames() for defaults and MergeKeyStoreSQLNames() for partial overrides.
|
||||||
|
type KeyStoreSQLNames struct {
|
||||||
|
GetUserKeys string // default: "resolvespec_keystore_get_user_keys"
|
||||||
|
CreateKey string // default: "resolvespec_keystore_create_key"
|
||||||
|
DeleteKey string // default: "resolvespec_keystore_delete_key"
|
||||||
|
ValidateKey string // default: "resolvespec_keystore_validate_key"
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultKeyStoreSQLNames returns a KeyStoreSQLNames with all default resolvespec_keystore_* values.
|
||||||
|
func DefaultKeyStoreSQLNames() *KeyStoreSQLNames {
|
||||||
|
return &KeyStoreSQLNames{
|
||||||
|
GetUserKeys: "resolvespec_keystore_get_user_keys",
|
||||||
|
CreateKey: "resolvespec_keystore_create_key",
|
||||||
|
DeleteKey: "resolvespec_keystore_delete_key",
|
||||||
|
ValidateKey: "resolvespec_keystore_validate_key",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeKeyStoreSQLNames returns a copy of base with any non-empty fields from override applied.
|
||||||
|
// If override is nil, a copy of base is returned.
|
||||||
|
func MergeKeyStoreSQLNames(base, override *KeyStoreSQLNames) *KeyStoreSQLNames {
|
||||||
|
if override == nil {
|
||||||
|
copied := *base
|
||||||
|
return &copied
|
||||||
|
}
|
||||||
|
merged := *base
|
||||||
|
if override.GetUserKeys != "" {
|
||||||
|
merged.GetUserKeys = override.GetUserKeys
|
||||||
|
}
|
||||||
|
if override.CreateKey != "" {
|
||||||
|
merged.CreateKey = override.CreateKey
|
||||||
|
}
|
||||||
|
if override.DeleteKey != "" {
|
||||||
|
merged.DeleteKey = override.DeleteKey
|
||||||
|
}
|
||||||
|
if override.ValidateKey != "" {
|
||||||
|
merged.ValidateKey = override.ValidateKey
|
||||||
|
}
|
||||||
|
return &merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateKeyStoreSQLNames checks that all non-empty procedure names are valid SQL identifiers.
|
||||||
|
func ValidateKeyStoreSQLNames(names *KeyStoreSQLNames) error {
|
||||||
|
fields := map[string]string{
|
||||||
|
"GetUserKeys": names.GetUserKeys,
|
||||||
|
"CreateKey": names.CreateKey,
|
||||||
|
"DeleteKey": names.DeleteKey,
|
||||||
|
"ValidateKey": names.ValidateKey,
|
||||||
|
}
|
||||||
|
for field, val := range fields {
|
||||||
|
if val != "" && !validSQLIdentifier.MatchString(val) {
|
||||||
|
return fmt.Errorf("KeyStoreSQLNames.%s contains invalid characters: %q", field, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user